encounter1997 committed
Commit 8d14048 • 1 Parent(s): c73a506

add gradio
.gitignore ADDED
@@ -0,0 +1,175 @@
1
+ # custom dirs
2
+ checkpoints/
3
+ outputs/
4
+
5
+ # Initially taken from Github's Python gitignore file
6
+
7
+ # Byte-compiled / optimized / DLL files
8
+ __pycache__/
9
+ *.py[cod]
10
+ *$py.class
11
+
12
+ # C extensions
13
+ *.so
14
+
15
+ # tests and logs
16
+ tests/fixtures/cached_*_text.txt
17
+ logs/
18
+ lightning_logs/
19
+ lang_code_data/
20
+
21
+ # Distribution / packaging
22
+ .Python
23
+ build/
24
+ develop-eggs/
25
+ dist/
26
+ downloads/
27
+ eggs/
28
+ .eggs/
29
+ lib/
30
+ lib64/
31
+ parts/
32
+ sdist/
33
+ var/
34
+ wheels/
35
+ *.egg-info/
36
+ .installed.cfg
37
+ *.egg
38
+ MANIFEST
39
+
40
+ # PyInstaller
41
+ # Usually these files are written by a python script from a template
42
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
43
+ *.manifest
44
+ *.spec
45
+
46
+ # Installer logs
47
+ pip-log.txt
48
+ pip-delete-this-directory.txt
49
+
50
+ # Unit test / coverage reports
51
+ htmlcov/
52
+ .tox/
53
+ .nox/
54
+ .coverage
55
+ .coverage.*
56
+ .cache
57
+ nosetests.xml
58
+ coverage.xml
59
+ *.cover
60
+ .hypothesis/
61
+ .pytest_cache/
62
+
63
+ # Translations
64
+ *.mo
65
+ *.pot
66
+
67
+ # Django stuff:
68
+ *.log
69
+ local_settings.py
70
+ db.sqlite3
71
+
72
+ # Flask stuff:
73
+ instance/
74
+ .webassets-cache
75
+
76
+ # Scrapy stuff:
77
+ .scrapy
78
+
79
+ # Sphinx documentation
80
+ docs/_build/
81
+
82
+ # PyBuilder
83
+ target/
84
+
85
+ # Jupyter Notebook
86
+ .ipynb_checkpoints
87
+
88
+ # IPython
89
+ profile_default/
90
+ ipython_config.py
91
+
92
+ # pyenv
93
+ .python-version
94
+
95
+ # celery beat schedule file
96
+ celerybeat-schedule
97
+
98
+ # SageMath parsed files
99
+ *.sage.py
100
+
101
+ # Environments
102
+ .env
103
+ .venv
104
+ env/
105
+ venv/
106
+ ENV/
107
+ env.bak/
108
+ venv.bak/
109
+
110
+ # Spyder project settings
111
+ .spyderproject
112
+ .spyproject
113
+
114
+ # Rope project settings
115
+ .ropeproject
116
+
117
+ # mkdocs documentation
118
+ /site
119
+
120
+ # mypy
121
+ .mypy_cache/
122
+ .dmypy.json
123
+ dmypy.json
124
+
125
+ # Pyre type checker
126
+ .pyre/
127
+
128
+ # vscode
129
+ .vs
130
+ .vscode
131
+
132
+ # Pycharm
133
+ .idea
134
+
135
+ # TF code
136
+ tensorflow_code
137
+
138
+ # Models
139
+ proc_data
140
+
141
+ # examples
142
+ runs
143
+ /runs_old
144
+ /wandb
145
+ /examples/runs
146
+ /examples/**/*.args
147
+ /examples/rag/sweep
148
+
149
+ # data
150
+ /data
151
+ serialization_dir
152
+
153
+ # emacs
154
+ *.*~
155
+ debug.env
156
+
157
+ # vim
158
+ .*.swp
159
+
160
+ #ctags
161
+ tags
162
+
163
+ # pre-commit
164
+ .pre-commit*
165
+
166
+ # .lock
167
+ *.lock
168
+
169
+ # DS_Store (MacOS)
170
+ .DS_Store
171
+ # RL pipelines may produce mp4 outputs
172
+ *.mp4
173
+
174
+ # dependencies
175
+ /transformers
app_gradio.py ADDED
@@ -0,0 +1,66 @@
1
+ #!/usr/bin/env python
2
+ # Most code is from https://huggingface.co/spaces/Tune-A-Video-library/Tune-A-Video-Training-UI
3
+
4
+
5
+ from __future__ import annotations
6
+
7
+ import os
8
+ from subprocess import getoutput
9
+
10
+ import gradio as gr
11
+ import torch
12
+
13
+ from gradio_demo.app_running import create_demo
14
+ from gradio_demo.runner import Runner
15
+
16
+ TITLE = '# [vid2vid-zero](https://github.com/baaivision/vid2vid-zero)'
17
+
18
+ ORIGINAL_SPACE_ID = 'BAAI/vid2vid-zero'
19
+ SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
20
+ GPU_DATA = getoutput('nvidia-smi')
21
+
22
+ if os.getenv('SYSTEM') == 'spaces' and SPACE_ID != ORIGINAL_SPACE_ID:
23
+ SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
24
+ else:
25
+ SETTINGS = 'Settings'
26
+
27
+ CUDA_NOT_AVAILABLE_WARNING = f'''## Attention - Running on CPU.
28
+ <center>
29
+ You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
30
+ You can use "T4 small/medium" to run this demo.
31
+ </center>
32
+ '''
33
+
34
+ HF_TOKEN_NOT_SPECIFIED_WARNING = f'''The environment variable `HF_TOKEN` is not specified. Feel free to specify your Hugging Face token with write permission if you don't want to manually provide it for every run.
35
+ <center>
36
+ You can check and create your Hugging Face tokens <a href="https://huggingface.co/settings/tokens" target="_blank">here</a>.
37
+ You can specify environment variables in the "Repository secrets" section of the {SETTINGS} tab.
38
+ </center>
39
+ '''
40
+
41
+ HF_TOKEN = os.getenv('HF_TOKEN')
42
+
43
+
44
+ def show_warning(warning_text: str) -> gr.Blocks:
45
+ with gr.Blocks() as demo:
46
+ with gr.Box():
47
+ gr.Markdown(warning_text)
48
+ return demo
49
+
50
+
51
+ pipe = None
52
+ runner = Runner(HF_TOKEN)
53
+
54
+ with gr.Blocks(css='gradio_demo/style.css') as demo:
55
+ if not torch.cuda.is_available():
56
+ show_warning(CUDA_NOT_AVAILABLE_WARNING)
57
+
58
+ gr.Markdown(TITLE)
59
+ with gr.Tabs():
60
+ with gr.TabItem('Zero-shot Testing'):
61
+ create_demo(runner, pipe)
62
+
63
+ if not HF_TOKEN:
64
+ show_warning(HF_TOKEN_NOT_SPECIFIED_WARNING)
65
+
66
+ demo.queue(max_size=1).launch(share=True)
configs/black-swan.yaml ADDED
@@ -0,0 +1,36 @@
1
+ pretrained_model_path: checkpoints/stable-diffusion-v1-4
2
+ output_dir: outputs/black-swan
3
+ input_data:
4
+ video_path: data/black-swan.mp4
5
+ prompt: a blackswan is swimming on the water
6
+ n_sample_frames: 8
7
+ width: 512
8
+ height: 512
9
+ sample_start_idx: 0
10
+ sample_frame_rate: 4
11
+ validation_data:
12
+ prompts:
13
+ - a black swan is swimming on the water, Van Gogh style
14
+ - a white swan is swimming on the water
15
+ video_length: 8
16
+ width: 512
17
+ height: 512
18
+ num_inference_steps: 50
19
+ guidance_scale: 7.5
20
+ num_inv_steps: 50
21
+ # args for null-text inv
22
+ use_null_inv: True
23
+ null_inner_steps: 1
24
+ null_base_lr: 1e-2
25
+ null_uncond_ratio: -0.5
26
+ null_normal_infer: True
27
+
28
+ input_batch_size: 1
29
+ seed: 33
30
+ mixed_precision: "no"
31
+ gradient_checkpointing: True
32
+ enable_xformers_memory_efficient_attention: True
33
+ # test-time adaptation
34
+ use_sc_attn: True
35
+ use_st_attn: True
36
+ st_attn_idx: 0
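Note: the configs/*.yaml files added in this commit are consumed by test_vid2vid_zero.py (added below), which loads the YAML with OmegaConf and unpacks its top-level keys into main(). A minimal sketch of that flow, assuming the stable-diffusion-v1-4 checkpoint has been placed under checkpoints/ as the config expects:

# Sketch only: mirrors the __main__ block of test_vid2vid_zero.py in this commit.
from omegaconf import OmegaConf
from test_vid2vid_zero import main

config = OmegaConf.load("configs/black-swan.yaml")  # any file under configs/ works the same way
# Top-level keys become keyword arguments of main(): pretrained_model_path,
# output_dir, input_data, validation_data, seed, use_sc_attn, use_st_attn, ...
main(**config)

In practice the script is launched through Accelerate; gradio_demo/runner.py below builds the command "accelerate launch test_vid2vid_zero.py --config <path>".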
configs/brown-bear.yaml ADDED
@@ -0,0 +1,37 @@
1
+ pretrained_model_path: checkpoints/stable-diffusion-v1-4
2
+ output_dir: outputs/brown-bear
3
+ input_data:
4
+ video_path: data/brown-bear.mp4
5
+ prompt: a brown bear is sitting on the ground
6
+ n_sample_frames: 8
7
+ width: 512
8
+ height: 512
9
+ sample_start_idx: 0
10
+ sample_frame_rate: 1
11
+ validation_data:
12
+ prompts:
13
+ - a brown bear is sitting on the grass
14
+ - a black bear is sitting on the grass
15
+ - a polar bear is sitting on the ground
16
+ video_length: 8
17
+ width: 512
18
+ height: 512
19
+ num_inference_steps: 50
20
+ guidance_scale: 7.5
21
+ num_inv_steps: 50
22
+ # args for null-text inv
23
+ use_null_inv: True
24
+ null_inner_steps: 1
25
+ null_base_lr: 1e-2
26
+ null_uncond_ratio: -0.5
27
+ null_normal_infer: True
28
+
29
+ input_batch_size: 1
30
+ seed: 33
31
+ mixed_precision: "no"
32
+ gradient_checkpointing: True
33
+ enable_xformers_memory_efficient_attention: True
34
+ # test-time adaptation
35
+ use_sc_attn: True
36
+ use_st_attn: True
37
+ st_attn_idx: 0
configs/car-moving.yaml ADDED
@@ -0,0 +1,37 @@
1
+ pretrained_model_path: checkpoints/stable-diffusion-v1-4
2
+ output_dir: outputs/car-moving
3
+ input_data:
4
+ video_path: data/car-moving.mp4
5
+ prompt: a car is moving on the road
6
+ n_sample_frames: 8
7
+ width: 512
8
+ height: 512
9
+ sample_start_idx: 0
10
+ sample_frame_rate: 1
11
+ validation_data:
12
+ prompts:
13
+ - a car is moving on the snow
14
+ - a jeep car is moving on the road
15
+ - a jeep car is moving on the desert
16
+ video_length: 8
17
+ width: 512
18
+ height: 512
19
+ num_inference_steps: 50
20
+ guidance_scale: 7.5
21
+ num_inv_steps: 50
22
+ # args for null-text inv
23
+ use_null_inv: True
24
+ null_inner_steps: 1
25
+ null_base_lr: 1e-2
26
+ null_uncond_ratio: -0.5
27
+ null_normal_infer: True
28
+
29
+ input_batch_size: 1
30
+ seed: 33
31
+ mixed_precision: "no"
32
+ gradient_checkpointing: True
33
+ enable_xformers_memory_efficient_attention: True
34
+ # test-time adaptation
35
+ use_sc_attn: True
36
+ use_st_attn: True
37
+ st_attn_idx: 0
configs/car-turn.yaml ADDED
@@ -0,0 +1,39 @@
1
+ pretrained_model_path: checkpoints/stable-diffusion-v1-4
2
+ output_dir: "outputs/car-turn"
3
+
4
+ input_data:
5
+ video_path: "data/car-turn.mp4"
6
+ prompt: "a jeep car is moving on the road"
7
+ n_sample_frames: 8
8
+ width: 512
9
+ height: 512
10
+ sample_start_idx: 0
11
+ sample_frame_rate: 6
12
+
13
+ validation_data:
14
+ prompts:
15
+ - "a jeep car is moving on the beach"
16
+ - "a jeep car is moving on the snow"
17
+ - "a Porsche car is moving on the desert"
18
+ video_length: 8
19
+ width: 512
20
+ height: 512
21
+ num_inference_steps: 50
22
+ guidance_scale: 7.5
23
+ num_inv_steps: 50
24
+ # args for null-text inv
25
+ use_null_inv: True
26
+ null_inner_steps: 1
27
+ null_base_lr: 1e-2
28
+ null_uncond_ratio: -0.5
29
+ null_normal_infer: True
30
+
31
+ input_batch_size: 1
32
+ seed: 33
33
+ mixed_precision: "no"
34
+ gradient_checkpointing: True
35
+ enable_xformers_memory_efficient_attention: True
36
+ # test-time adaptation
37
+ use_sc_attn: True
38
+ use_st_attn: True
39
+ st_attn_idx: 0
configs/child-riding.yaml ADDED
@@ -0,0 +1,40 @@
1
+
2
+ pretrained_model_path: checkpoints/stable-diffusion-v1-4
3
+ output_dir: outputs/child-riding
4
+
5
+ input_data:
6
+ video_path: data/child-riding.mp4
7
+ prompt: "a child is riding a bike on the road"
8
+ n_sample_frames: 8
9
+ width: 512
10
+ height: 512
11
+ sample_start_idx: 0
12
+ sample_frame_rate: 1
13
+
14
+ validation_data:
15
+ # inv_latent: "outputs_2d/car-turn/inv_latents/ddim_latent-0.pt" # latent inversed w/o SCAttn !
16
+ prompts:
17
+ - a lego child is riding a bike on the road
18
+ - a child is riding a bike on the flooded road
19
+ video_length: 8
20
+ width: 512
21
+ height: 512
22
+ num_inference_steps: 50
23
+ guidance_scale: 7.5
24
+ num_inv_steps: 50
25
+ # args for null-text inv
26
+ use_null_inv: True
27
+ null_inner_steps: 1
28
+ null_base_lr: 1e-2
29
+ null_uncond_ratio: -0.5
30
+ null_normal_infer: True
31
+
32
+ input_batch_size: 1
33
+ seed: 33
34
+ mixed_precision: "no"
35
+ gradient_checkpointing: True
36
+ enable_xformers_memory_efficient_attention: True
37
+ # test-time adaptation
38
+ use_sc_attn: True
39
+ use_st_attn: True
40
+ st_attn_idx: 0
configs/cow-walking.yaml ADDED
@@ -0,0 +1,37 @@
1
+ pretrained_model_path: checkpoints/stable-diffusion-v1-4
2
+ output_dir: outputs/cow-walking
3
+ input_data:
4
+ video_path: data/cow-walking.mp4
5
+ prompt: a cow is walking on the grass
6
+ n_sample_frames: 8
7
+ width: 512
8
+ height: 512
9
+ sample_start_idx: 0
10
+ sample_frame_rate: 2
11
+ validation_data:
12
+ prompts:
13
+ - a lion is walking on the grass
14
+ - a dog is walking on the grass
15
+ - a cow is walking on the snow
16
+ video_length: 8
17
+ width: 512
18
+ height: 512
19
+ num_inference_steps: 50
20
+ guidance_scale: 7.5
21
+ num_inv_steps: 50
22
+ # args for null-text inv
23
+ use_null_inv: True
24
+ null_inner_steps: 1
25
+ null_base_lr: 1e-2
26
+ null_uncond_ratio: -0.5
27
+ null_normal_infer: True
28
+
29
+ input_batch_size: 1
30
+ seed: 33
31
+ mixed_precision: "no"
32
+ gradient_checkpointing: True
33
+ enable_xformers_memory_efficient_attention: True
34
+ # test-time adaptation
35
+ use_sc_attn: True
36
+ use_st_attn: True
37
+ st_attn_idx: 0
configs/dog-walking.yaml ADDED
@@ -0,0 +1,35 @@
1
+ pretrained_model_path: checkpoints/stable-diffusion-v1-4
2
+ output_dir: outputs/dog_walking
3
+ input_data:
4
+ video_path: data/dog-walking.mp4
5
+ prompt: a dog is walking on the ground
6
+ n_sample_frames: 8
7
+ width: 512
8
+ height: 512
9
+ sample_start_idx: 15
10
+ sample_frame_rate: 3
11
+ validation_data:
12
+ prompts:
13
+ - a dog is walking on the ground, Van Gogh style
14
+ video_length: 8
15
+ width: 512
16
+ height: 512
17
+ num_inference_steps: 50
18
+ guidance_scale: 7.5
19
+ num_inv_steps: 50
20
+ # args for null-text inv
21
+ use_null_inv: True
22
+ null_inner_steps: 1
23
+ null_base_lr: 1e-2
24
+ null_uncond_ratio: -0.5
25
+ null_normal_infer: True
26
+
27
+ input_batch_size: 1
28
+ seed: 33
29
+ mixed_precision: "no"
30
+ gradient_checkpointing: True
31
+ enable_xformers_memory_efficient_attention: True
32
+ # test-time adaptation
33
+ use_sc_attn: True
34
+ use_st_attn: True
35
+ st_attn_idx: 0
configs/horse-running.yaml ADDED
@@ -0,0 +1,36 @@
1
+ pretrained_model_path: checkpoints/stable-diffusion-v1-4
2
+ output_dir: outputs/horse-running
3
+ input_data:
4
+ video_path: data/horse-running.mp4
5
+ prompt: a horse is running on the beach
6
+ n_sample_frames: 8
7
+ width: 512
8
+ height: 512
9
+ sample_start_idx: 0
10
+ sample_frame_rate: 2
11
+ validation_data:
12
+ prompts:
13
+ - a dog is running on the beach
14
+ - a dog is running on the desert
15
+ video_length: 8
16
+ width: 512
17
+ height: 512
18
+ num_inference_steps: 50
19
+ guidance_scale: 7.5
20
+ num_inv_steps: 50
21
+ # args for null-text inv
22
+ use_null_inv: True
23
+ null_inner_steps: 1
24
+ null_base_lr: 1e-2
25
+ null_uncond_ratio: -0.5
26
+ null_normal_infer: True
27
+
28
+ input_batch_size: 1
29
+ seed: 33
30
+ mixed_precision: "no"
31
+ gradient_checkpointing: True
32
+ enable_xformers_memory_efficient_attention: True
33
+ # test-time adaptation
34
+ use_sc_attn: True
35
+ use_st_attn: True
36
+ st_attn_idx: 0
configs/lion-roaring.yaml ADDED
@@ -0,0 +1,37 @@
1
+ pretrained_model_path: checkpoints/stable-diffusion-v1-4
2
+ output_dir: ./outputs/lion-roaring
3
+ input_data:
4
+ video_path: data/lion-roaring.mp4
5
+ prompt: a lion is roaring
6
+ n_sample_frames: 8
7
+ width: 512
8
+ height: 512
9
+ sample_start_idx: 0
10
+ sample_frame_rate: 2
11
+ validation_data:
12
+ prompts:
13
+ - a lego lion is roaring
14
+ - a wolf is roaring, anime style
15
+ - a lion is roaring, anime style
16
+ video_length: 8
17
+ width: 512
18
+ height: 512
19
+ num_inference_steps: 50
20
+ guidance_scale: 7.5
21
+ num_inv_steps: 50
22
+ # args for null-text inv
23
+ use_null_inv: True
24
+ null_inner_steps: 1
25
+ null_base_lr: 1e-2
26
+ null_uncond_ratio: -0.5
27
+ null_normal_infer: True
28
+
29
+ input_batch_size: 1
30
+ seed: 33
31
+ mixed_precision: "no"
32
+ gradient_checkpointing: True
33
+ enable_xformers_memory_efficient_attention: True
34
+ # test-time adaptation
35
+ use_sc_attn: True
36
+ use_st_attn: True
37
+ st_attn_idx: 0
configs/man-running.yaml ADDED
@@ -0,0 +1,37 @@
1
+ pretrained_model_path: checkpoints/stable-diffusion-v1-4
2
+ output_dir: outputs/man-running
3
+ input_data:
4
+ video_path: data/man-running.mp4
5
+ prompt: a man is running
6
+ n_sample_frames: 8
7
+ width: 512
8
+ height: 512
9
+ sample_start_idx: 25
10
+ sample_frame_rate: 2
11
+ validation_data:
12
+ prompts:
13
+ - Stephen Curry is running in Time Square
14
+ - a man is running, Van Gogh style
15
+ - a man is running in New York City
16
+ video_length: 8
17
+ width: 512
18
+ height: 512
19
+ num_inference_steps: 50
20
+ guidance_scale: 7.5
21
+ num_inv_steps: 50
22
+ # args for null-text inv
23
+ use_null_inv: True
24
+ null_inner_steps: 1
25
+ null_base_lr: 1e-2
26
+ null_uncond_ratio: -0.5
27
+ null_normal_infer: True
28
+
29
+ input_batch_size: 1
30
+ seed: 33
31
+ mixed_precision: "no"
32
+ gradient_checkpointing: True
33
+ enable_xformers_memory_efficient_attention: True
34
+ # test-time adaptation
35
+ use_sc_attn: True
36
+ use_st_attn: True
37
+ st_attn_idx: 0
configs/man-surfing.yaml ADDED
@@ -0,0 +1,36 @@
1
+ pretrained_model_path: checkpoints/stable-diffusion-v1-4
2
+ output_dir: outputs/man-surfing
3
+ input_data:
4
+ video_path: data/man-surfing.mp4
5
+ prompt: a man is surfing
6
+ n_sample_frames: 8
7
+ width: 512
8
+ height: 512
9
+ sample_start_idx: 0
10
+ sample_frame_rate: 3
11
+ validation_data:
12
+ prompts:
13
+ - a boy is surfing in the desert
14
+ - Iron Man is surfing
15
+ video_length: 8
16
+ width: 512
17
+ height: 512
18
+ num_inference_steps: 50
19
+ guidance_scale: 7.5
20
+ num_inv_steps: 50
21
+ # args for null-text inv
22
+ use_null_inv: True
23
+ null_inner_steps: 1
24
+ null_base_lr: 1e-2
25
+ null_uncond_ratio: -0.5
26
+ null_normal_infer: True
27
+
28
+ input_batch_size: 1
29
+ seed: 33
30
+ mixed_precision: "no"
31
+ gradient_checkpointing: True
32
+ enable_xformers_memory_efficient_attention: True
33
+ # test-time adaptation
34
+ use_sc_attn: True
35
+ use_st_attn: True
36
+ st_attn_idx: 0
configs/rabbit-watermelon.yaml ADDED
@@ -0,0 +1,40 @@
1
+ pretrained_model_path: checkpoints/stable-diffusion-v1-4
2
+ output_dir: "outputs/rabbit-watermelon"
3
+
4
+ input_data:
5
+ video_path: "data/rabbit-watermelon.mp4"
6
+ prompt: "a rabbit is eating a watermelon"
7
+ n_sample_frames: 8
8
+ width: 512
9
+ height: 512
10
+ sample_start_idx: 0
11
+ sample_frame_rate: 6
12
+
13
+ validation_data:
14
+ prompts:
15
+ - "a tiger is eating a watermelon"
16
+ - "a rabbit is eating an orange"
17
+ - "a rabbit is eating a pizza"
18
+ - "a puppy is eating an orange"
19
+ video_length: 8
20
+ width: 512
21
+ height: 512
22
+ num_inference_steps: 50
23
+ guidance_scale: 7.5
24
+ num_inv_steps: 50
25
+ # args for null-text inv
26
+ use_null_inv: True
27
+ null_inner_steps: 1
28
+ null_base_lr: 1e-2
29
+ null_uncond_ratio: -0.5
30
+ null_normal_infer: True
31
+
32
+ input_batch_size: 1
33
+ seed: 33
34
+ mixed_precision: "no"
35
+ gradient_checkpointing: True
36
+ enable_xformers_memory_efficient_attention: True
37
+ # test-time adaptation
38
+ use_sc_attn: True
39
+ use_st_attn: True
40
+ st_attn_idx: 0
configs/skateboard-dog.yaml ADDED
@@ -0,0 +1,35 @@
1
+ pretrained_model_path: checkpoints/stable-diffusion-v1-4
2
+ output_dir: outputs/skateboard-dog
3
+ input_data:
4
+ video_path: data/skateboard-dog.avi
5
+ prompt: A man with a dog skateboarding on the road
6
+ n_sample_frames: 8
7
+ width: 512
8
+ height: 512
9
+ sample_start_idx: 0
10
+ sample_frame_rate: 3
11
+ validation_data:
12
+ prompts:
13
+ - A man with a dog skateboarding on the desert
14
+ video_length: 8
15
+ width: 512
16
+ height: 512
17
+ num_inference_steps: 50
18
+ guidance_scale: 7.5
19
+ num_inv_steps: 50
20
+ # args for null-text inv
21
+ use_null_inv: True
22
+ null_inner_steps: 1
23
+ null_base_lr: 1e-2
24
+ null_uncond_ratio: -0.5
25
+ null_normal_infer: True
26
+
27
+ input_batch_size: 1
28
+ seed: 33
29
+ mixed_precision: "no"
30
+ gradient_checkpointing: True
31
+ enable_xformers_memory_efficient_attention: True
32
+ # test-time adaptation
33
+ use_sc_attn: True
34
+ use_st_attn: True
35
+ st_attn_idx: 0
configs/skateboard-man.yaml ADDED
@@ -0,0 +1,35 @@
1
+ pretrained_model_path: checkpoints/stable-diffusion-v1-4
2
+ output_dir: outputs/skateboard-man
3
+ input_data:
4
+ video_path: data/skateboard-man.mp4
5
+ prompt: a man is playing skateboard on the ground
6
+ n_sample_frames: 8
7
+ width: 512
8
+ height: 512
9
+ sample_start_idx: 0
10
+ sample_frame_rate: 3
11
+ validation_data:
12
+ prompts:
13
+ - a boy is playing skateboard on the ground
14
+ video_length: 8
15
+ width: 512
16
+ height: 512
17
+ num_inference_steps: 50
18
+ guidance_scale: 7.5
19
+ num_inv_steps: 50
20
+ # args for null-text inv
21
+ use_null_inv: True
22
+ null_inner_steps: 1
23
+ null_base_lr: 1e-2
24
+ null_uncond_ratio: -0.5
25
+ null_normal_infer: True
26
+
27
+ input_batch_size: 1
28
+ seed: 33
29
+ mixed_precision: "no"
30
+ gradient_checkpointing: True
31
+ enable_xformers_memory_efficient_attention: True
32
+ # test-time adaptation
33
+ use_sc_attn: True
34
+ use_st_attn: True
35
+ st_attn_idx: 0
gradio_demo/app_running.py ADDED
@@ -0,0 +1,178 @@
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+
7
+ import gradio as gr
8
+
9
+ from gradio_demo.runner import Runner
10
+
11
+
12
+ def create_demo(runner: Runner,
13
+ pipe: InferencePipeline | None = None) -> gr.Blocks:
14
+ hf_token = os.getenv('HF_TOKEN')
15
+ with gr.Blocks() as demo:
16
+ with gr.Row():
17
+ with gr.Column():
18
+ with gr.Box():
19
+ gr.Markdown('Input Data')
20
+ input_video = gr.File(label='Input video')
21
+ input_prompt = gr.Textbox(
22
+ label='Input prompt',
23
+ max_lines=1,
24
+ placeholder='A car is moving on the road.')
25
+ gr.Markdown('''
26
+ - Upload a video and write an `Input Prompt` that describes the video.
27
+ ''')
28
+
29
+ with gr.Column():
30
+ with gr.Box():
31
+ gr.Markdown('Input Parameters')
32
+ with gr.Row():
33
+ model_path = gr.Text(
34
+ label='Path to off-the-shelf model',
35
+ value='CompVis/stable-diffusion-v1-4',
36
+ max_lines=1)
37
+ resolution = gr.Dropdown(choices=['512', '768'],
38
+ value='512',
39
+ label='Resolution',
40
+ visible=False)
41
+
42
+ with gr.Accordion('Advanced settings', open=False):
43
+ sample_start_idx = gr.Number(
44
+ label='Start Frame Index', value=0)
45
+ sample_frame_rate = gr.Number(
46
+ label='Frame Rate', value=1)
47
+ n_sample_frames = gr.Number(
48
+ label='Number of Frames', value=8)
49
+ guidance_scale = gr.Number(
50
+ label='Guidance Scale', value=7.5)
51
+ seed = gr.Slider(label='Seed',
52
+ minimum=0,
53
+ maximum=100000,
54
+ step=1,
55
+ randomize=True,
56
+ value=33)
57
+ input_token = gr.Text(label='Hugging Face Write Token',
58
+ placeholder='',
59
+ visible=False if hf_token else True)
60
+ gr.Markdown('''
61
+ - Upload an input video or choose an example below
62
+ - Set hyperparameters & click start
63
+ - It takes a few minutes to download the model on the first run
64
+ ''')
65
+
66
+ with gr.Row():
67
+ with gr.Column():
68
+ validation_prompt = gr.Text(
69
+ label='Validation Prompt',
70
+ placeholder=
71
+ 'prompt to test the model, e.g., a Lego man is surfing')
72
+
73
+ remove_gpu_after_running = gr.Checkbox(
74
+ label='Remove GPU after running',
75
+ value=False,
76
+ interactive=bool(os.getenv('SPACE_ID')),
77
+ visible=False)
78
+
79
+ with gr.Row():
80
+ result = gr.Video(label='Result')
81
+
82
+ # examples
83
+ with gr.Row():
84
+ examples = [
85
+ [
86
+ 'CompVis/stable-diffusion-v1-4',
87
+ "data/car-moving.mp4",
88
+ 'A car is moving on the road.',
89
+ 8, 0, 1,
90
+ 'A jeep car is moving on the desert.',
91
+ 7.5, 512, 33,
92
+ False, None,
93
+ ],
94
+
95
+ [
96
+ 'CompVis/stable-diffusion-v1-4',
97
+ "data/black-swan.mp4",
98
+ 'A blackswan is swimming on the water.',
99
+ 8, 0, 4,
100
+ 'A white swan is swimming on the water.',
101
+ 7.5, 512, 33,
102
+ False, None,
103
+ ],
104
+
105
+ [
106
+ 'CompVis/stable-diffusion-v1-4',
107
+ "data/child-riding.mp4",
108
+ 'A child is riding a bike on the road.',
109
+ 8, 0, 1,
110
+ 'A lego child is riding a bike on the road.',
111
+ 7.5, 512, 33,
112
+ False, None,
113
+ ],
114
+
115
+ [
116
+ 'CompVis/stable-diffusion-v1-4',
117
+ "data/car-turn.mp4",
118
+ 'A jeep car is moving on the road.',
119
+ 8, 0, 6,
120
+ 'A jeep car is moving on the snow.',
121
+ 7.5, 512, 33,
122
+ False, None,
123
+ ],
124
+
125
+ [
126
+ 'CompVis/stable-diffusion-v1-4',
127
+ "data/rabbit-watermelon.mp4",
128
+ 'A rabbit is eating a watermelon.',
129
+ 8, 0, 6,
130
+ 'A puppy is eating an orange.',
131
+ 7.5, 512, 33,
132
+ False, None,
133
+ ],
134
+
135
+ [
136
+ 'CompVis/stable-diffusion-v1-4',
137
+ "data/brown-bear.mp4",
138
+ 'A brown bear is sitting on the ground.',
139
+ 8, 0, 6,
140
+ 'A black bear is sitting on the grass.',
141
+ 7.5, 512, 33,
142
+ False, None,
143
+ ],
144
+ ]
145
+ gr.Examples(examples=examples,
146
+ fn=runner.run_vid2vid_zero,
147
+ inputs=[
148
+ model_path, input_video, input_prompt,
149
+ n_sample_frames, sample_start_idx, sample_frame_rate,
150
+ validation_prompt, guidance_scale, resolution, seed,
151
+ remove_gpu_after_running,
152
+ input_token,
153
+ ],
154
+ outputs=result,
155
+ cache_examples=os.getenv('SYSTEM') == 'spaces'
156
+ )
157
+
158
+ # run
159
+ run_button_vid2vid_zero = gr.Button('Start vid2vid-zero')
160
+ run_button_vid2vid_zero.click(
161
+ fn=runner.run_vid2vid_zero,
162
+ inputs=[
163
+ model_path, input_video, input_prompt,
164
+ n_sample_frames, sample_start_idx, sample_frame_rate,
165
+ validation_prompt, guidance_scale, resolution, seed,
166
+ remove_gpu_after_running,
167
+ input_token,
168
+ ],
169
+ outputs=result)
170
+
171
+ return demo
172
+
173
+
174
+ if __name__ == '__main__':
175
+ hf_token = os.getenv('HF_TOKEN')
176
+ runner = Runner(hf_token)
177
+ demo = create_demo(runner)
178
+ demo.queue(max_size=1).launch(share=False)
gradio_demo/runner.py ADDED
@@ -0,0 +1,133 @@
1
+ from __future__ import annotations
2
+
3
+ import datetime
4
+ import os
5
+ import pathlib
6
+ import shlex
7
+ import shutil
8
+ import subprocess
9
+ import sys
10
+
11
+ import gradio as gr
12
+ import slugify
13
+ import torch
14
+ import huggingface_hub
15
+ from huggingface_hub import HfApi
16
+ from omegaconf import OmegaConf
17
+
18
+
19
+ ORIGINAL_SPACE_ID = 'BAAI/vid2vid-zero'
20
+ SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
21
+
22
+
23
+ class Runner:
24
+ def __init__(self, hf_token: str | None = None):
25
+ self.hf_token = hf_token
26
+
27
+ self.checkpoint_dir = pathlib.Path('checkpoints')
28
+ self.checkpoint_dir.mkdir(exist_ok=True)
29
+
30
+ def download_base_model(self, base_model_id: str, token=None) -> str:
31
+ model_dir = self.checkpoint_dir / base_model_id
32
+ if not model_dir.exists():
33
+ org_name = base_model_id.split('/')[0]
34
+ org_dir = self.checkpoint_dir / org_name
35
+ org_dir.mkdir(exist_ok=True)
36
+ print(f'https://huggingface.co/{base_model_id}')
37
+ if token is None:
38
+ subprocess.run(shlex.split(
39
+ f'git clone https://huggingface.co/{base_model_id}'),
40
+ cwd=org_dir)
41
+ return model_dir.as_posix()
42
+ else:
43
+ temp_path = huggingface_hub.snapshot_download(base_model_id, use_auth_token=token)
44
+ print(temp_path, org_dir)
45
+ # subprocess.run(shlex.split(f'mv {temp_path} {model_dir.as_posix()}'))
46
+ # return model_dir.as_posix()
47
+ return temp_path
48
+
49
+ def join_model_library_org(self, token: str) -> None:
50
+ subprocess.run(
51
+ shlex.split(
52
+ f'curl -X POST -H "Authorization: Bearer {token}" -H "Content-Type: application/json" {URL_TO_JOIN_MODEL_LIBRARY_ORG}'
53
+ ))
54
+
55
+ def run_vid2vid_zero(
56
+ self,
57
+ model_path: str,
58
+ input_video: str,
59
+ prompt: str,
60
+ n_sample_frames: int,
61
+ sample_start_idx: int,
62
+ sample_frame_rate: int,
63
+ validation_prompt: str,
64
+ guidance_scale: float,
65
+ resolution: str,
66
+ seed: int,
67
+ remove_gpu_after_running: bool,
68
+ input_token: str | None = None,
69
+ ) -> str:
70
+
71
+ if not torch.cuda.is_available():
72
+ raise gr.Error('CUDA is not available.')
73
+ if input_video is None:
74
+ raise gr.Error('You need to upload a video.')
75
+ if not prompt:
76
+ raise gr.Error('The input prompt is missing.')
77
+ if not validation_prompt:
78
+ raise gr.Error('The validation prompt is missing.')
79
+
80
+ resolution = int(resolution)
81
+ n_sample_frames = int(n_sample_frames)
82
+ sample_start_idx = int(sample_start_idx)
83
+ sample_frame_rate = int(sample_frame_rate)
84
+
85
+ repo_dir = pathlib.Path(__file__).parent
86
+ prompt_path = prompt.replace(' ', '_')
87
+ output_dir = repo_dir / 'outputs' / prompt_path
88
+ output_dir.mkdir(parents=True, exist_ok=True)
89
+
90
+ config = OmegaConf.load('configs/black-swan.yaml')
91
+ # config.pretrained_model_path = self.download_base_model(model_path, token=input_token)
92
+ config.pretrained_model_path = "checkpoints/stable-diffusion-v1-4" # TODO
93
+
94
+ config.output_dir = output_dir.as_posix()
95
+ config.input_data.video_path = input_video.name # type: ignore
96
+ config.input_data.prompt = prompt
97
+ config.input_data.n_sample_frames = n_sample_frames
98
+ config.input_data.width = resolution
99
+ config.input_data.height = resolution
100
+ config.input_data.sample_start_idx = sample_start_idx
101
+ config.input_data.sample_frame_rate = sample_frame_rate
102
+
103
+ config.validation_data.prompts = [validation_prompt]
104
+ config.validation_data.video_length = 8
105
+ config.validation_data.width = resolution
106
+ config.validation_data.height = resolution
107
+ config.validation_data.num_inference_steps = 50
108
+ config.validation_data.guidance_scale = guidance_scale
109
+
110
+ config.input_batch_size = 1
111
+ config.seed = seed
112
+
113
+ config_path = output_dir / 'config.yaml'
114
+ with open(config_path, 'w') as f:
115
+ OmegaConf.save(config, f)
116
+
117
+ command = f'accelerate launch test_vid2vid_zero.py --config {config_path}'
118
+ subprocess.run(shlex.split(command))
119
+
120
+ output_video_path = os.path.join(output_dir, "sample-all.mp4")
121
+ print(f"video path for gradio: {output_video_path}")
122
+ message = 'Running completed!'
123
+ print(message)
124
+
125
+ if remove_gpu_after_running:
126
+ space_id = os.getenv('SPACE_ID')
127
+ if space_id:
128
+ api = HfApi(
129
+ token=self.hf_token if self.hf_token else input_token)
130
+ api.request_space_hardware(repo_id=space_id,
131
+ hardware='cpu-basic')
132
+
133
+ return output_video_path
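Note: a minimal sketch of how the Gradio UI drives Runner.run_vid2vid_zero. The argument values are copied from the Examples list in gradio_demo/app_running.py above; the open() call merely stands in for the file-like object (with a .name attribute) that Gradio passes for an uploaded video. Running it requires a CUDA GPU and the checkpoint expected under checkpoints/stable-diffusion-v1-4.

import os
from gradio_demo.runner import Runner

runner = Runner(hf_token=os.getenv("HF_TOKEN"))
result_path = runner.run_vid2vid_zero(
    model_path="CompVis/stable-diffusion-v1-4",
    input_video=open("data/car-moving.mp4", "rb"),  # stand-in for Gradio's upload object
    prompt="A car is moving on the road.",
    n_sample_frames=8,
    sample_start_idx=0,
    sample_frame_rate=1,
    validation_prompt="A jeep car is moving on the desert.",
    guidance_scale=7.5,
    resolution="512",
    seed=33,
    remove_gpu_after_running=False,
)
print(result_path)  # <output_dir>/sample-all.mp4, as returned above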
gradio_demo/style.css ADDED
@@ -0,0 +1,3 @@
1
+ h1 {
2
+ text-align: center;
3
+ }
requirements.txt ADDED
@@ -0,0 +1,13 @@
1
+ torch==1.12.1
2
+ torchvision==0.13.1
3
+ diffusers[torch]==0.11.1
4
+ transformers>=4.25.1
5
+ bitsandbytes==0.35.4
6
+ decord==0.6.0
7
+ accelerate
8
+ tensorboard
9
+ modelcards
10
+ omegaconf
11
+ einops
12
+ imageio
13
+ ftfy
test_vid2vid_zero.py ADDED
@@ -0,0 +1,266 @@
1
+ import argparse
2
+ import datetime
3
+ import logging
4
+ import inspect
5
+ import math
6
+ import os
7
+ import warnings
8
+ from typing import Dict, Optional, Tuple
9
+ from omegaconf import OmegaConf
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ import torch.utils.checkpoint
14
+
15
+ import diffusers
16
+ import transformers
17
+ from accelerate import Accelerator
18
+ from accelerate.logging import get_logger
19
+ from accelerate.utils import set_seed
20
+ from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler
21
+ from diffusers.optimization import get_scheduler
22
+ from diffusers.utils import check_min_version
23
+ from diffusers.utils.import_utils import is_xformers_available
24
+ from tqdm.auto import tqdm
25
+ from transformers import CLIPTextModel, CLIPTokenizer
26
+
27
+ from vid2vid_zero.models.unet_2d_condition import UNet2DConditionModel
28
+ from vid2vid_zero.data.dataset import VideoDataset
29
+ from vid2vid_zero.pipelines.pipeline_vid2vid_zero import Vid2VidZeroPipeline
30
+ from vid2vid_zero.util import save_videos_grid, save_videos_as_images, ddim_inversion
31
+ from einops import rearrange
32
+
33
+ from vid2vid_zero.p2p.p2p_stable import AttentionReplace, AttentionRefine
34
+ from vid2vid_zero.p2p.ptp_utils import register_attention_control
35
+ from vid2vid_zero.p2p.null_text_w_ptp import NullInversion
36
+
37
+
38
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
39
+ check_min_version("0.10.0.dev0")
40
+
41
+ logger = get_logger(__name__, log_level="INFO")
42
+
43
+
44
+ def prepare_control(unet, prompts, validation_data):
45
+ assert len(prompts) == 2
46
+
47
+ print(prompts[0])
48
+ print(prompts[1])
49
+ length1 = len(prompts[0].split(' '))
50
+ length2 = len(prompts[1].split(' '))
51
+ if length1 == length2:
52
+ # prepare for attn guidance
53
+ cross_replace_steps = 0.8
54
+ self_replace_steps = 0.4
55
+ controller = AttentionReplace(prompts, validation_data['num_inference_steps'],
56
+ cross_replace_steps=cross_replace_steps,
57
+ self_replace_steps=self_replace_steps)
58
+ else:
59
+ cross_replace_steps = 0.8
60
+ self_replace_steps = 0.4
61
+ controller = AttentionRefine(prompts, validation_data['num_inference_steps'],
62
+ cross_replace_steps=cross_replace_steps,
63
+ self_replace_steps=self_replace_steps)
64
+
65
+ print(controller)
66
+ register_attention_control(unet, controller)
67
+
68
+ # the update of unet forward function is inplace
69
+ return cross_replace_steps, self_replace_steps
70
+
71
+
72
+ def main(
73
+ pretrained_model_path: str,
74
+ output_dir: str,
75
+ input_data: Dict,
76
+ validation_data: Dict,
77
+ input_batch_size: int = 1,
78
+ gradient_accumulation_steps: int = 1,
79
+ gradient_checkpointing: bool = True,
80
+ mixed_precision: Optional[str] = "fp16",
81
+ enable_xformers_memory_efficient_attention: bool = True,
82
+ seed: Optional[int] = None,
83
+ use_sc_attn: bool = True,
84
+ use_st_attn: bool = True,
85
+ st_attn_idx: int = 0,
86
+ fps: int = 1,
87
+ ):
88
+ *_, config = inspect.getargvalues(inspect.currentframe())
89
+
90
+ accelerator = Accelerator(
91
+ gradient_accumulation_steps=gradient_accumulation_steps,
92
+ mixed_precision=mixed_precision,
93
+ )
94
+
95
+ # Make one log on every process with the configuration for debugging.
96
+ logging.basicConfig(
97
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
98
+ datefmt="%m/%d/%Y %H:%M:%S",
99
+ level=logging.INFO,
100
+ )
101
+ logger.info(accelerator.state, main_process_only=False)
102
+ if accelerator.is_local_main_process:
103
+ transformers.utils.logging.set_verbosity_warning()
104
+ diffusers.utils.logging.set_verbosity_info()
105
+ else:
106
+ transformers.utils.logging.set_verbosity_error()
107
+ diffusers.utils.logging.set_verbosity_error()
108
+
109
+ # If passed along, set the training seed now.
110
+ if seed is not None:
111
+ set_seed(seed)
112
+
113
+ # Handle the output folder creation
114
+ if accelerator.is_main_process:
115
+ os.makedirs(output_dir, exist_ok=True)
116
+ os.makedirs(f"{output_dir}/sample", exist_ok=True)
117
+ OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
118
+
119
+ # Load tokenizer and models.
120
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
121
+ text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
122
+ vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
123
+ unet = UNet2DConditionModel.from_pretrained(
124
+ pretrained_model_path, subfolder="unet", use_sc_attn=use_sc_attn,
125
+ use_st_attn=use_st_attn, st_attn_idx=st_attn_idx)
126
+
127
+ # Freeze vae, text_encoder, and unet
128
+ vae.requires_grad_(False)
129
+ text_encoder.requires_grad_(False)
130
+ unet.requires_grad_(False)
131
+
132
+ if enable_xformers_memory_efficient_attention:
133
+ if is_xformers_available():
134
+ unet.enable_xformers_memory_efficient_attention()
135
+ else:
136
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
137
+
138
+ if gradient_checkpointing:
139
+ unet.enable_gradient_checkpointing()
140
+
141
+ # Get the training dataset
142
+ input_dataset = VideoDataset(**input_data)
143
+
144
+ # Preprocessing the dataset
145
+ input_dataset.prompt_ids = tokenizer(
146
+ input_dataset.prompt, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
147
+ ).input_ids[0]
148
+
149
+ # DataLoaders creation:
150
+ input_dataloader = torch.utils.data.DataLoader(
151
+ input_dataset, batch_size=input_batch_size
152
+ )
153
+
154
+ # Get the validation pipeline
155
+ validation_pipeline = Vid2VidZeroPipeline(
156
+ vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
157
+ scheduler=DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler"),
158
+ safety_checker=None, feature_extractor=None,
159
+ )
160
+ validation_pipeline.enable_vae_slicing()
161
+ ddim_inv_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder='scheduler')
162
+ ddim_inv_scheduler.set_timesteps(validation_data.num_inv_steps)
163
+
164
+ # Prepare everything with our `accelerator`.
165
+ unet, input_dataloader = accelerator.prepare(
166
+ unet, input_dataloader,
167
+ )
168
+
169
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
170
+ # as these models are only used for inference, keeping weights in full precision is not required.
171
+ weight_dtype = torch.float32
172
+ if accelerator.mixed_precision == "fp16":
173
+ weight_dtype = torch.float16
174
+ elif accelerator.mixed_precision == "bf16":
175
+ weight_dtype = torch.bfloat16
176
+
177
+ # Move text_encode and vae to gpu and cast to weight_dtype
178
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
179
+ vae.to(accelerator.device, dtype=weight_dtype)
180
+
181
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
182
+ num_update_steps_per_epoch = math.ceil(len(input_dataloader) / gradient_accumulation_steps)
183
+
184
+ # We need to initialize the trackers we use, and also store our configuration.
185
+ # The trackers initializes automatically on the main process.
186
+ if accelerator.is_main_process:
187
+ accelerator.init_trackers("vid2vid-zero")
188
+
189
+ # Zero-shot Eval!
190
+ total_batch_size = input_batch_size * accelerator.num_processes * gradient_accumulation_steps
191
+
192
+ logger.info("***** Running training *****")
193
+ logger.info(f" Num examples = {len(input_dataset)}")
194
+ logger.info(f" Instantaneous batch size per device = {input_batch_size}")
195
+ logger.info(f" Total input batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
196
+ global_step = 0
197
+
198
+ unet.eval()
199
+ for step, batch in enumerate(input_dataloader):
200
+ samples = []
201
+ pixel_values = batch["pixel_values"].to(weight_dtype)
202
+ # save input video
203
+ video = (pixel_values / 2 + 0.5).clamp(0, 1).detach().cpu()
204
+ video = video.permute(0, 2, 1, 3, 4) # (b, f, c, h, w)
205
+ samples.append(video)
206
+ # start processing
207
+ video_length = pixel_values.shape[1]
208
+ pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w")
209
+ latents = vae.encode(pixel_values).latent_dist.sample()
210
+ # take video as input
211
+ latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
212
+ latents = latents * 0.18215
213
+
214
+ generator = torch.Generator(device="cuda")
215
+ generator.manual_seed(seed)
216
+
217
+ # perform inversion
218
+ ddim_inv_latent = None
219
+ if validation_data.use_null_inv:
220
+ null_inversion = NullInversion(
221
+ model=validation_pipeline, guidance_scale=validation_data.guidance_scale, null_inv_with_prompt=False,
222
+ null_normal_infer=validation_data.null_normal_infer,
223
+ )
224
+ ddim_inv_latent, uncond_embeddings = null_inversion.invert(
225
+ latents, input_dataset.prompt, verbose=True,
226
+ null_inner_steps=validation_data.null_inner_steps,
227
+ null_base_lr=validation_data.null_base_lr,
228
+ )
229
+ ddim_inv_latent = ddim_inv_latent.to(weight_dtype)
230
+ uncond_embeddings = [embed.to(weight_dtype) for embed in uncond_embeddings]
231
+ else:
232
+ ddim_inv_latent = ddim_inversion(
233
+ validation_pipeline, ddim_inv_scheduler, video_latent=latents,
234
+ num_inv_steps=validation_data.num_inv_steps, prompt="",
235
+ normal_infer=True, # we don't want to use sc-attn or dense-attn for inversion, just use plain sd inference
236
+ )[-1].to(weight_dtype)
237
+ uncond_embeddings = None
238
+
239
+ ddim_inv_latent = ddim_inv_latent.repeat(2, 1, 1, 1, 1)
240
+
241
+ for idx, prompt in enumerate(validation_data.prompts):
242
+ prompts = [input_dataset.prompt, prompt] # a list of two prompts
243
+ cross_replace_steps, self_replace_steps = prepare_control(unet=unet, prompts=prompts, validation_data=validation_data)
244
+
245
+ sample = validation_pipeline(prompts, generator=generator, latents=ddim_inv_latent,
246
+ uncond_embeddings=uncond_embeddings,
247
+ **validation_data).images
248
+
249
+ assert sample.shape[0] == 2
250
+ sample_inv, sample_gen = sample.chunk(2)
251
+ # add input for vis
252
+ save_videos_grid(sample_gen, f"{output_dir}/sample/{prompts[1]}.gif", fps=fps)
253
+ samples.append(sample_gen)
254
+
255
+ samples = torch.concat(samples)
256
+ save_path = f"{output_dir}/sample-all.gif"
257
+ save_videos_grid(samples, save_path, fps=fps)
258
+ logger.info(f"Saved samples to {save_path}")
259
+
260
+
261
+ if __name__ == "__main__":
262
+ parser = argparse.ArgumentParser()
263
+ parser.add_argument("--config", type=str, default="./configs/vid2vid_zero.yaml")
264
+ args = parser.parse_args()
265
+
266
+ main(**OmegaConf.load(args.config))
vid2vid_zero/data/dataset.py ADDED
@@ -0,0 +1,44 @@
1
+ import decord
2
+ decord.bridge.set_bridge('torch')
3
+
4
+ from torch.utils.data import Dataset
5
+ from einops import rearrange
6
+
7
+
8
+ class VideoDataset(Dataset):
9
+ def __init__(
10
+ self,
11
+ video_path: str,
12
+ prompt: str,
13
+ width: int = 512,
14
+ height: int = 512,
15
+ n_sample_frames: int = 8,
16
+ sample_start_idx: int = 0,
17
+ sample_frame_rate: int = 1,
18
+ ):
19
+ self.video_path = video_path
20
+ self.prompt = prompt
21
+ self.prompt_ids = None
22
+
23
+ self.width = width
24
+ self.height = height
25
+ self.n_sample_frames = n_sample_frames
26
+ self.sample_start_idx = sample_start_idx
27
+ self.sample_frame_rate = sample_frame_rate
28
+
29
+ def __len__(self):
30
+ return 1
31
+
32
+ def __getitem__(self, index):
33
+ # load and sample video frames
34
+ vr = decord.VideoReader(self.video_path, width=self.width, height=self.height)
35
+ sample_index = list(range(self.sample_start_idx, len(vr), self.sample_frame_rate))[:self.n_sample_frames]
36
+ video = vr.get_batch(sample_index)
37
+ video = rearrange(video, "f h w c -> f c h w")
38
+
39
+ example = {
40
+ "pixel_values": (video / 127.5 - 1.0),
41
+ "prompt_ids": self.prompt_ids
42
+ }
43
+
44
+ return example
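Note: a minimal sketch of how VideoDataset is used, assuming data/black-swan.mp4 is available locally. test_vid2vid_zero.py additionally fills dataset.prompt_ids with CLIP token ids before wrapping the dataset in a DataLoader; here the single example is indexed directly.

from vid2vid_zero.data.dataset import VideoDataset

dataset = VideoDataset(
    video_path="data/black-swan.mp4",
    prompt="a blackswan is swimming on the water",
    width=512,
    height=512,
    n_sample_frames=8,
    sample_start_idx=0,
    sample_frame_rate=4,
)
example = dataset[0]
# frames are resized by decord, rescaled from [0, 255] to [-1, 1],
# and laid out as (frames, channels, height, width)
print(example["pixel_values"].shape)  # torch.Size([8, 3, 512, 512])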
vid2vid_zero/models/attention_2d.py ADDED
@@ -0,0 +1,434 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+
10
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
11
+ from diffusers.modeling_utils import ModelMixin
12
+ from diffusers.utils import BaseOutput
13
+ from diffusers.utils.import_utils import is_xformers_available
14
+ from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm
15
+
16
+ from einops import rearrange, repeat
17
+
18
+
19
+ @dataclass
20
+ class Transformer2DModelOutput(BaseOutput):
21
+ sample: torch.FloatTensor
22
+
23
+
24
+ if is_xformers_available():
25
+ import xformers
26
+ import xformers.ops
27
+ else:
28
+ xformers = None
29
+
30
+
31
+ class Transformer2DModel(ModelMixin, ConfigMixin):
32
+ @register_to_config
33
+ def __init__(
34
+ self,
35
+ num_attention_heads: int = 16,
36
+ attention_head_dim: int = 88,
37
+ in_channels: Optional[int] = None,
38
+ num_layers: int = 1,
39
+ dropout: float = 0.0,
40
+ norm_num_groups: int = 32,
41
+ cross_attention_dim: Optional[int] = None,
42
+ attention_bias: bool = False,
43
+ sample_size: Optional[int] = None,
44
+ num_vector_embeds: Optional[int] = None,
45
+ activation_fn: str = "geglu",
46
+ num_embeds_ada_norm: Optional[int] = None,
47
+ use_linear_projection: bool = False,
48
+ only_cross_attention: bool = False,
49
+ upcast_attention: bool = False,
50
+ use_sc_attn: bool = False,
51
+ use_st_attn: bool = False,
52
+ ):
53
+ super().__init__()
54
+ self.use_linear_projection = use_linear_projection
55
+ self.num_attention_heads = num_attention_heads
56
+ self.attention_head_dim = attention_head_dim
57
+ inner_dim = num_attention_heads * attention_head_dim
58
+
59
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` and quantized image embeddings of shape `(batch_size, num_image_vectors)`
60
+ # Define whether input is continuous or discrete depending on configuration
61
+ self.is_input_continuous = in_channels is not None
62
+ self.is_input_vectorized = num_vector_embeds is not None
63
+
64
+ if self.is_input_continuous and self.is_input_vectorized:
65
+ raise ValueError(
66
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
67
+ " sure that either `in_channels` or `num_vector_embeds` is None."
68
+ )
69
+ elif not self.is_input_continuous and not self.is_input_vectorized:
70
+ raise ValueError(
71
+ f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make"
72
+ " sure that either `in_channels` or `num_vector_embeds` is not None."
73
+ )
74
+
75
+ # 2. Define input layers
76
+ if self.is_input_continuous:
77
+ self.in_channels = in_channels
78
+
79
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
80
+ if use_linear_projection:
81
+ self.proj_in = nn.Linear(in_channels, inner_dim)
82
+ else:
83
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
84
+ else:
85
+ raise NotImplementedError
86
+
87
+ # Define transformers blocks
88
+ self.transformer_blocks = nn.ModuleList(
89
+ [
90
+ BasicTransformerBlock(
91
+ inner_dim,
92
+ num_attention_heads,
93
+ attention_head_dim,
94
+ dropout=dropout,
95
+ cross_attention_dim=cross_attention_dim,
96
+ activation_fn=activation_fn,
97
+ num_embeds_ada_norm=num_embeds_ada_norm,
98
+ attention_bias=attention_bias,
99
+ only_cross_attention=only_cross_attention,
100
+ upcast_attention=upcast_attention,
101
+ use_sc_attn=use_sc_attn,
102
+ use_st_attn=(d == 0 and use_st_attn),
103
+ )
104
+ for d in range(num_layers)
105
+ ]
106
+ )
107
+
108
+ # 4. Define output layers
109
+ if use_linear_projection:
110
+ self.proj_out = nn.Linear(in_channels, inner_dim)
111
+ else:
112
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
113
+
114
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True, normal_infer: bool = False):
115
+ # Input
116
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
117
+ video_length = hidden_states.shape[2]
118
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
119
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
120
+
121
+ batch, channel, height, weight = hidden_states.shape
122
+ residual = hidden_states
123
+
124
+ hidden_states = self.norm(hidden_states)
125
+ if not self.use_linear_projection:
126
+ hidden_states = self.proj_in(hidden_states)
127
+ inner_dim = hidden_states.shape[1]
128
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
129
+ else:
130
+ inner_dim = hidden_states.shape[1]
131
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
132
+ hidden_states = self.proj_in(hidden_states)
133
+
134
+ # Blocks
135
+ for block in self.transformer_blocks:
136
+ hidden_states = block(
137
+ hidden_states,
138
+ encoder_hidden_states=encoder_hidden_states,
139
+ timestep=timestep,
140
+ video_length=video_length,
141
+ normal_infer=normal_infer,
142
+ )
143
+
144
+ # Output
145
+ if not self.use_linear_projection:
146
+ hidden_states = (
147
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
148
+ )
149
+ hidden_states = self.proj_out(hidden_states)
150
+ else:
151
+ hidden_states = self.proj_out(hidden_states)
152
+ hidden_states = (
153
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
154
+ )
155
+
156
+ output = hidden_states + residual
157
+
158
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
159
+ if not return_dict:
160
+ return (output,)
161
+
162
+ return Transformer2DModelOutput(sample=output)
163
+
164
+
165
+ class BasicTransformerBlock(nn.Module):
166
+ def __init__(
167
+ self,
168
+ dim: int,
169
+ num_attention_heads: int,
170
+ attention_head_dim: int,
171
+ dropout=0.0,
172
+ cross_attention_dim: Optional[int] = None,
173
+ activation_fn: str = "geglu",
174
+ num_embeds_ada_norm: Optional[int] = None,
175
+ attention_bias: bool = False,
176
+ only_cross_attention: bool = False,
177
+ upcast_attention: bool = False,
178
+ use_sc_attn: bool = False,
179
+ use_st_attn: bool = False,
180
+ ):
181
+ super().__init__()
182
+ self.only_cross_attention = only_cross_attention
183
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
184
+
185
+ # Attn with temporal modeling
186
+ self.use_sc_attn = use_sc_attn
187
+ self.use_st_attn = use_st_attn
188
+
189
+ attn_type = SparseCausalAttention if self.use_sc_attn else CrossAttention
190
+ attn_type = SpatialTemporalAttention if self.use_st_attn else attn_type
191
+ self.attn1 = attn_type(
192
+ query_dim=dim,
193
+ heads=num_attention_heads,
194
+ dim_head=attention_head_dim,
195
+ dropout=dropout,
196
+ bias=attention_bias,
197
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
198
+ upcast_attention=upcast_attention,
199
+ ) # is a self-attention
200
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
201
+
202
+ # Cross-Attn
203
+ if cross_attention_dim is not None:
204
+ self.attn2 = CrossAttention(
205
+ query_dim=dim,
206
+ cross_attention_dim=cross_attention_dim,
207
+ heads=num_attention_heads,
208
+ dim_head=attention_head_dim,
209
+ dropout=dropout,
210
+ bias=attention_bias,
211
+ upcast_attention=upcast_attention,
212
+ ) # is self-attn if encoder_hidden_states is none
213
+ else:
214
+ self.attn2 = None
215
+
216
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
217
+
218
+ if cross_attention_dim is not None:
219
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
220
+ else:
221
+ self.norm2 = None
222
+
223
+ # 3. Feed-forward
224
+ self.norm3 = nn.LayerNorm(dim)
225
+
226
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
227
+ if not is_xformers_available():
228
+ print("Here is how to install it")
229
+ raise ModuleNotFoundError(
230
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
231
+ " xformers",
232
+ name="xformers",
233
+ )
234
+ elif not torch.cuda.is_available():
235
+ raise ValueError(
236
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
237
+ " available for GPU "
238
+ )
239
+ else:
240
+ try:
241
+ # Make sure we can run the memory efficient attention
242
+ _ = xformers.ops.memory_efficient_attention(
243
+ torch.randn((1, 2, 40), device="cuda"),
244
+ torch.randn((1, 2, 40), device="cuda"),
245
+ torch.randn((1, 2, 40), device="cuda"),
246
+ )
247
+ except Exception as e:
248
+ raise e
249
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
250
+ if self.attn2 is not None:
251
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
252
+ # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
253
+
254
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None, normal_infer=False):
255
+ # SparseCausal-Attention
256
+ norm_hidden_states = (
257
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
258
+ )
259
+
260
+ if self.only_cross_attention:
261
+ hidden_states = (
262
+ self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
263
+ )
264
+ else:
265
+ if self.use_sc_attn or self.use_st_attn:
266
+ hidden_states = self.attn1(
267
+ norm_hidden_states, attention_mask=attention_mask, video_length=video_length, normal_infer=normal_infer,
268
+ ) + hidden_states
269
+ else:
270
+ # shape of hidden_states: (b*f, len, dim)
271
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
272
+
273
+ if self.attn2 is not None:
274
+ # Cross-Attention
275
+ norm_hidden_states = (
276
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
277
+ )
278
+ hidden_states = (
279
+ self.attn2(
280
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
281
+ )
282
+ + hidden_states
283
+ )
284
+
285
+ # Feed-forward
286
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
287
+
288
+ return hidden_states
289
+
290
+
291
+ class SparseCausalAttention(CrossAttention):
292
+ def forward_sc_attn(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
293
+ batch_size, sequence_length, _ = hidden_states.shape
294
+
295
+ encoder_hidden_states = encoder_hidden_states
296
+
297
+ if self.group_norm is not None:
298
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
299
+
300
+ query = self.to_q(hidden_states)
301
+ dim = query.shape[-1]
302
+ query = self.reshape_heads_to_batch_dim(query)
303
+
304
+ if self.added_kv_proj_dim is not None:
305
+ raise NotImplementedError
306
+
307
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
308
+ key = self.to_k(encoder_hidden_states)
309
+ value = self.to_v(encoder_hidden_states)
310
+
311
+ former_frame_index = torch.arange(video_length) - 1
312
+ former_frame_index[0] = 0
313
+
314
+ key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
315
+ key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2)
316
+ key = rearrange(key, "b f d c -> (b f) d c")
317
+
318
+ value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
319
+ value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2)
320
+ value = rearrange(value, "b f d c -> (b f) d c")
321
+
322
+ key = self.reshape_heads_to_batch_dim(key)
323
+ value = self.reshape_heads_to_batch_dim(value)
324
+
325
+ if attention_mask is not None:
326
+ if attention_mask.shape[-1] != query.shape[1]:
327
+ target_length = query.shape[1]
328
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
329
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
330
+
331
+ # attention, what we cannot get enough of
332
+ if self._use_memory_efficient_attention_xformers:
333
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
334
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
335
+ hidden_states = hidden_states.to(query.dtype)
336
+ else:
337
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
338
+ hidden_states = self._attention(query, key, value, attention_mask)
339
+ else:
340
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
341
+
342
+ # linear proj
343
+ hidden_states = self.to_out[0](hidden_states)
344
+
345
+ # dropout
346
+ hidden_states = self.to_out[1](hidden_states)
347
+ return hidden_states
348
+
349
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, normal_infer=False):
350
+ if normal_infer:
351
+ return super().forward(
352
+ hidden_states=hidden_states,
353
+ encoder_hidden_states=encoder_hidden_states,
354
+ attention_mask=attention_mask,
355
+ # video_length=video_length,
356
+ )
357
+ else:
358
+ return self.forward_sc_attn(
359
+ hidden_states=hidden_states,
360
+ encoder_hidden_states=encoder_hidden_states,
361
+ attention_mask=attention_mask,
362
+ video_length=video_length,
363
+ )
364
+
365
+ class SpatialTemporalAttention(CrossAttention):
366
+ def forward_dense_attn(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
367
+ batch_size, sequence_length, _ = hidden_states.shape
368
+
369
+ encoder_hidden_states = encoder_hidden_states
370
+
371
+ if self.group_norm is not None:
372
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
373
+
374
+ query = self.to_q(hidden_states)
375
+ dim = query.shape[-1]
376
+ query = self.reshape_heads_to_batch_dim(query)
377
+
378
+ if self.added_kv_proj_dim is not None:
379
+ raise NotImplementedError
380
+
381
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
382
+ key = self.to_k(encoder_hidden_states)
383
+ value = self.to_v(encoder_hidden_states)
384
+
385
+ key = rearrange(key, "(b f) n d -> b f n d", f=video_length)
386
+ key = key.unsqueeze(1).repeat(1, video_length, 1, 1, 1) # (b f f n d)
387
+ key = rearrange(key, "b f g n d -> (b f) (g n) d")
388
+
389
+ value = rearrange(value, "(b f) n d -> b f n d", f=video_length)
390
+ value = value.unsqueeze(1).repeat(1, video_length, 1, 1, 1) # (b f f n d)
391
+ value = rearrange(value, "b f g n d -> (b f) (g n) d")
392
+
393
+ key = self.reshape_heads_to_batch_dim(key)
394
+ value = self.reshape_heads_to_batch_dim(value)
395
+
396
+ if attention_mask is not None:
397
+ if attention_mask.shape[-1] != query.shape[1]:
398
+ target_length = query.shape[1]
399
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
400
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
401
+
402
+ # attention, what we cannot get enough of
403
+ if self._use_memory_efficient_attention_xformers:
404
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
405
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
406
+ hidden_states = hidden_states.to(query.dtype)
407
+ else:
408
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
409
+ hidden_states = self._attention(query, key, value, attention_mask)
410
+ else:
411
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
412
+
413
+ # linear proj
414
+ hidden_states = self.to_out[0](hidden_states)
415
+
416
+ # dropout
417
+ hidden_states = self.to_out[1](hidden_states)
418
+ return hidden_states
419
+
420
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, normal_infer=False):
421
+ if normal_infer:
422
+ return super().forward(
423
+ hidden_states=hidden_states,
424
+ encoder_hidden_states=encoder_hidden_states,
425
+ attention_mask=attention_mask,
426
+ # video_length=video_length,
427
+ )
428
+ else:
429
+ return self.forward_dense_attn(
430
+ hidden_states=hidden_states,
431
+ encoder_hidden_states=encoder_hidden_states,
432
+ attention_mask=attention_mask,
433
+ video_length=video_length,
434
+ )
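
Note (not part of the committed diff): SparseCausalAttention above restricts each frame's keys and values to the tokens of the first frame and of the immediately preceding frame. A minimal sketch of that key/value construction, with illustrative tensor sizes only, is:

    import torch
    from einops import rearrange

    batch, video_length, tokens, dim = 1, 4, 8, 16
    # keys arrive flattened over frames, as in the module above: (b*f, tokens, dim)
    key = torch.randn(batch * video_length, tokens, dim)

    former_frame_index = torch.arange(video_length) - 1
    former_frame_index[0] = 0  # frame 0 has no predecessor, so it reuses itself

    key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
    # concatenate the tokens of frame 0 and of the previous frame along the token axis
    key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2)
    key = rearrange(key, "b f d c -> (b f) d c")
    print(key.shape)  # torch.Size([4, 16, 16]): each frame now attends over 2*tokens keys
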
vid2vid_zero/models/resnet_2d.py ADDED
@@ -0,0 +1,209 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from einops import rearrange
8
+
9
+
10
+ class InflatedConv3d(nn.Conv2d):
11
+ def forward(self, x):
12
+ video_length = x.shape[2]
13
+
14
+ x = rearrange(x, "b c f h w -> (b f) c h w")
15
+ x = super().forward(x)
16
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
17
+
18
+ return x
19
+
20
+
21
+ class Upsample2D(nn.Module):
22
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
23
+ super().__init__()
24
+ self.channels = channels
25
+ self.out_channels = out_channels or channels
26
+ self.use_conv = use_conv
27
+ self.use_conv_transpose = use_conv_transpose
28
+ self.name = name
29
+
30
+ conv = None
31
+ if use_conv_transpose:
32
+ raise NotImplementedError
33
+ elif use_conv:
34
+ conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
35
+
36
+ if name == "conv":
37
+ self.conv = conv
38
+ else:
39
+ self.Conv2d_0 = conv
40
+
41
+ def forward(self, hidden_states, output_size=None):
42
+ assert hidden_states.shape[1] == self.channels
43
+
44
+ if self.use_conv_transpose:
45
+ raise NotImplementedError
46
+
47
+ # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
48
+ dtype = hidden_states.dtype
49
+ if dtype == torch.bfloat16:
50
+ hidden_states = hidden_states.to(torch.float32)
51
+
52
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
53
+ if hidden_states.shape[0] >= 64:
54
+ hidden_states = hidden_states.contiguous()
55
+
56
+ # if `output_size` is passed we force the interpolation output
57
+ # size and do not make use of `scale_factor=2`
58
+ if output_size is None:
59
+ hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
60
+ else:
61
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
62
+
63
+ # If the input is bfloat16, we cast back to bfloat16
64
+ if dtype == torch.bfloat16:
65
+ hidden_states = hidden_states.to(dtype)
66
+
67
+ if self.use_conv:
68
+ if self.name == "conv":
69
+ hidden_states = self.conv(hidden_states)
70
+ else:
71
+ hidden_states = self.Conv2d_0(hidden_states)
72
+
73
+ return hidden_states
74
+
75
+
76
+ class Downsample2D(nn.Module):
77
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
78
+ super().__init__()
79
+ self.channels = channels
80
+ self.out_channels = out_channels or channels
81
+ self.use_conv = use_conv
82
+ self.padding = padding
83
+ stride = 2
84
+ self.name = name
85
+
86
+ if use_conv:
87
+ conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
88
+ else:
89
+ raise NotImplementedError
90
+
91
+ if name == "conv":
92
+ self.Conv2d_0 = conv
93
+ self.conv = conv
94
+ elif name == "Conv2d_0":
95
+ self.conv = conv
96
+ else:
97
+ self.conv = conv
98
+
99
+ def forward(self, hidden_states):
100
+ assert hidden_states.shape[1] == self.channels
101
+ if self.use_conv and self.padding == 0:
102
+ raise NotImplementedError
103
+
104
+ assert hidden_states.shape[1] == self.channels
105
+ hidden_states = self.conv(hidden_states)
106
+
107
+ return hidden_states
108
+
109
+
110
+ class ResnetBlock2D(nn.Module):
111
+ def __init__(
112
+ self,
113
+ *,
114
+ in_channels,
115
+ out_channels=None,
116
+ conv_shortcut=False,
117
+ dropout=0.0,
118
+ temb_channels=512,
119
+ groups=32,
120
+ groups_out=None,
121
+ pre_norm=True,
122
+ eps=1e-6,
123
+ non_linearity="swish",
124
+ time_embedding_norm="default",
125
+ output_scale_factor=1.0,
126
+ use_in_shortcut=None,
127
+ ):
128
+ super().__init__()
129
+ self.pre_norm = pre_norm
130
+ self.pre_norm = True
131
+ self.in_channels = in_channels
132
+ out_channels = in_channels if out_channels is None else out_channels
133
+ self.out_channels = out_channels
134
+ self.use_conv_shortcut = conv_shortcut
135
+ self.time_embedding_norm = time_embedding_norm
136
+ self.output_scale_factor = output_scale_factor
137
+
138
+ if groups_out is None:
139
+ groups_out = groups
140
+
141
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
142
+
143
+ self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
144
+
145
+ if temb_channels is not None:
146
+ if self.time_embedding_norm == "default":
147
+ time_emb_proj_out_channels = out_channels
148
+ elif self.time_embedding_norm == "scale_shift":
149
+ time_emb_proj_out_channels = out_channels * 2
150
+ else:
151
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
152
+
153
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
154
+ else:
155
+ self.time_emb_proj = None
156
+
157
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
158
+ self.dropout = torch.nn.Dropout(dropout)
159
+ self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
160
+
161
+ if non_linearity == "swish":
162
+ self.nonlinearity = lambda x: F.silu(x)
163
+ elif non_linearity == "mish":
164
+ self.nonlinearity = Mish()
165
+ elif non_linearity == "silu":
166
+ self.nonlinearity = nn.SiLU()
167
+
168
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
169
+
170
+ self.conv_shortcut = None
171
+ if self.use_in_shortcut:
172
+ self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
173
+
174
+ def forward(self, input_tensor, temb):
175
+ hidden_states = input_tensor
176
+
177
+ hidden_states = self.norm1(hidden_states)
178
+ hidden_states = self.nonlinearity(hidden_states)
179
+
180
+ hidden_states = self.conv1(hidden_states)
181
+
182
+ if temb is not None:
183
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
184
+
185
+ if temb is not None and self.time_embedding_norm == "default":
186
+ hidden_states = hidden_states + temb
187
+
188
+ hidden_states = self.norm2(hidden_states)
189
+
190
+ if temb is not None and self.time_embedding_norm == "scale_shift":
191
+ scale, shift = torch.chunk(temb, 2, dim=1)
192
+ hidden_states = hidden_states * (1 + scale) + shift
193
+
194
+ hidden_states = self.nonlinearity(hidden_states)
195
+
196
+ hidden_states = self.dropout(hidden_states)
197
+ hidden_states = self.conv2(hidden_states)
198
+
199
+ if self.conv_shortcut is not None:
200
+ input_tensor = self.conv_shortcut(input_tensor)
201
+
202
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
203
+
204
+ return output_tensor
205
+
206
+
207
+ class Mish(torch.nn.Module):
208
+ def forward(self, hidden_states):
209
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
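
Note (not part of the committed diff): InflatedConv3d lets pretrained 2D convolution weights operate on video latents by folding the frame axis into the batch axis, so each frame is convolved independently. A minimal usage sketch with made-up shapes:

    import torch
    from vid2vid_zero.models.resnet_2d import InflatedConv3d

    conv = InflatedConv3d(4, 320, kernel_size=3, padding=1)
    video_latents = torch.randn(2, 4, 8, 64, 64)  # (batch, channels, frames, height, width)
    out = conv(video_latents)
    print(out.shape)  # torch.Size([2, 320, 8, 64, 64])
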
vid2vid_zero/models/unet_2d_blocks.py ADDED
@@ -0,0 +1,609 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from .attention_2d import Transformer2DModel
7
+ from .resnet_2d import Downsample2D, ResnetBlock2D, Upsample2D
8
+
9
+
10
+ def get_down_block(
11
+ down_block_type,
12
+ num_layers,
13
+ in_channels,
14
+ out_channels,
15
+ temb_channels,
16
+ add_downsample,
17
+ resnet_eps,
18
+ resnet_act_fn,
19
+ attn_num_head_channels,
20
+ resnet_groups=None,
21
+ cross_attention_dim=None,
22
+ downsample_padding=None,
23
+ dual_cross_attention=False,
24
+ use_linear_projection=False,
25
+ only_cross_attention=False,
26
+ upcast_attention=False,
27
+ resnet_time_scale_shift="default",
28
+ use_sc_attn=False,
29
+ use_st_attn=False,
30
+ ):
31
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
32
+ if down_block_type == "DownBlock2D":
33
+ return DownBlock2D(
34
+ num_layers=num_layers,
35
+ in_channels=in_channels,
36
+ out_channels=out_channels,
37
+ temb_channels=temb_channels,
38
+ add_downsample=add_downsample,
39
+ resnet_eps=resnet_eps,
40
+ resnet_act_fn=resnet_act_fn,
41
+ resnet_groups=resnet_groups,
42
+ downsample_padding=downsample_padding,
43
+ resnet_time_scale_shift=resnet_time_scale_shift,
44
+ )
45
+ elif down_block_type == "CrossAttnDownBlock2D":
46
+ if cross_attention_dim is None:
47
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
48
+ return CrossAttnDownBlock2D(
49
+ num_layers=num_layers,
50
+ in_channels=in_channels,
51
+ out_channels=out_channels,
52
+ temb_channels=temb_channels,
53
+ add_downsample=add_downsample,
54
+ resnet_eps=resnet_eps,
55
+ resnet_act_fn=resnet_act_fn,
56
+ resnet_groups=resnet_groups,
57
+ downsample_padding=downsample_padding,
58
+ cross_attention_dim=cross_attention_dim,
59
+ attn_num_head_channels=attn_num_head_channels,
60
+ dual_cross_attention=dual_cross_attention,
61
+ use_linear_projection=use_linear_projection,
62
+ only_cross_attention=only_cross_attention,
63
+ upcast_attention=upcast_attention,
64
+ resnet_time_scale_shift=resnet_time_scale_shift,
65
+ use_sc_attn=use_sc_attn,
66
+ use_st_attn=use_st_attn,
67
+ )
68
+ raise ValueError(f"{down_block_type} does not exist.")
69
+
70
+
71
+ def get_up_block(
72
+ up_block_type,
73
+ num_layers,
74
+ in_channels,
75
+ out_channels,
76
+ prev_output_channel,
77
+ temb_channels,
78
+ add_upsample,
79
+ resnet_eps,
80
+ resnet_act_fn,
81
+ attn_num_head_channels,
82
+ resnet_groups=None,
83
+ cross_attention_dim=None,
84
+ dual_cross_attention=False,
85
+ use_linear_projection=False,
86
+ only_cross_attention=False,
87
+ upcast_attention=False,
88
+ resnet_time_scale_shift="default",
89
+ use_sc_attn=False,
90
+ use_st_attn=False,
91
+ ):
92
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
93
+ if up_block_type == "UpBlock2D":
94
+ return UpBlock2D(
95
+ num_layers=num_layers,
96
+ in_channels=in_channels,
97
+ out_channels=out_channels,
98
+ prev_output_channel=prev_output_channel,
99
+ temb_channels=temb_channels,
100
+ add_upsample=add_upsample,
101
+ resnet_eps=resnet_eps,
102
+ resnet_act_fn=resnet_act_fn,
103
+ resnet_groups=resnet_groups,
104
+ resnet_time_scale_shift=resnet_time_scale_shift,
105
+ )
106
+ elif up_block_type == "CrossAttnUpBlock2D":
107
+ if cross_attention_dim is None:
108
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
109
+ return CrossAttnUpBlock2D(
110
+ num_layers=num_layers,
111
+ in_channels=in_channels,
112
+ out_channels=out_channels,
113
+ prev_output_channel=prev_output_channel,
114
+ temb_channels=temb_channels,
115
+ add_upsample=add_upsample,
116
+ resnet_eps=resnet_eps,
117
+ resnet_act_fn=resnet_act_fn,
118
+ resnet_groups=resnet_groups,
119
+ cross_attention_dim=cross_attention_dim,
120
+ attn_num_head_channels=attn_num_head_channels,
121
+ dual_cross_attention=dual_cross_attention,
122
+ use_linear_projection=use_linear_projection,
123
+ only_cross_attention=only_cross_attention,
124
+ upcast_attention=upcast_attention,
125
+ resnet_time_scale_shift=resnet_time_scale_shift,
126
+ use_sc_attn=use_sc_attn,
127
+ use_st_attn=use_st_attn,
128
+ )
129
+ raise ValueError(f"{up_block_type} does not exist.")
130
+
131
+
132
+ class UNetMidBlock2DCrossAttn(nn.Module):
133
+ def __init__(
134
+ self,
135
+ in_channels: int,
136
+ temb_channels: int,
137
+ dropout: float = 0.0,
138
+ num_layers: int = 1,
139
+ resnet_eps: float = 1e-6,
140
+ resnet_time_scale_shift: str = "default",
141
+ resnet_act_fn: str = "swish",
142
+ resnet_groups: int = 32,
143
+ resnet_pre_norm: bool = True,
144
+ attn_num_head_channels=1,
145
+ output_scale_factor=1.0,
146
+ cross_attention_dim=1280,
147
+ dual_cross_attention=False,
148
+ use_linear_projection=False,
149
+ upcast_attention=False,
150
+ use_sc_attn=False,
151
+ use_st_attn=False,
152
+ ):
153
+ super().__init__()
154
+
155
+ self.has_cross_attention = True
156
+ self.attn_num_head_channels = attn_num_head_channels
157
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
158
+
159
+ # there is always at least one resnet
160
+ resnets = [
161
+ ResnetBlock2D(
162
+ in_channels=in_channels,
163
+ out_channels=in_channels,
164
+ temb_channels=temb_channels,
165
+ eps=resnet_eps,
166
+ groups=resnet_groups,
167
+ dropout=dropout,
168
+ time_embedding_norm=resnet_time_scale_shift,
169
+ non_linearity=resnet_act_fn,
170
+ output_scale_factor=output_scale_factor,
171
+ pre_norm=resnet_pre_norm,
172
+ )
173
+ ]
174
+ attentions = []
175
+
176
+ for _ in range(num_layers):
177
+ if dual_cross_attention:
178
+ raise NotImplementedError
179
+ attentions.append(
180
+ Transformer2DModel(
181
+ attn_num_head_channels,
182
+ in_channels // attn_num_head_channels,
183
+ in_channels=in_channels,
184
+ num_layers=1,
185
+ cross_attention_dim=cross_attention_dim,
186
+ norm_num_groups=resnet_groups,
187
+ use_linear_projection=use_linear_projection,
188
+ upcast_attention=upcast_attention,
189
+ use_sc_attn=use_sc_attn,
190
+ use_st_attn=True if (use_st_attn and _ == 0) else False,
191
+ )
192
+ )
193
+ resnets.append(
194
+ ResnetBlock2D(
195
+ in_channels=in_channels,
196
+ out_channels=in_channels,
197
+ temb_channels=temb_channels,
198
+ eps=resnet_eps,
199
+ groups=resnet_groups,
200
+ dropout=dropout,
201
+ time_embedding_norm=resnet_time_scale_shift,
202
+ non_linearity=resnet_act_fn,
203
+ output_scale_factor=output_scale_factor,
204
+ pre_norm=resnet_pre_norm,
205
+ )
206
+ )
207
+
208
+ self.attentions = nn.ModuleList(attentions)
209
+ self.resnets = nn.ModuleList(resnets)
210
+
211
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, normal_infer=False):
212
+ hidden_states = self.resnets[0](hidden_states, temb)
213
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
214
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, normal_infer=normal_infer).sample
215
+ hidden_states = resnet(hidden_states, temb)
216
+
217
+ return hidden_states
218
+
219
+
220
+ class CrossAttnDownBlock2D(nn.Module):
221
+ def __init__(
222
+ self,
223
+ in_channels: int,
224
+ out_channels: int,
225
+ temb_channels: int,
226
+ dropout: float = 0.0,
227
+ num_layers: int = 1,
228
+ resnet_eps: float = 1e-6,
229
+ resnet_time_scale_shift: str = "default",
230
+ resnet_act_fn: str = "swish",
231
+ resnet_groups: int = 32,
232
+ resnet_pre_norm: bool = True,
233
+ attn_num_head_channels=1,
234
+ cross_attention_dim=1280,
235
+ output_scale_factor=1.0,
236
+ downsample_padding=1,
237
+ add_downsample=True,
238
+ dual_cross_attention=False,
239
+ use_linear_projection=False,
240
+ only_cross_attention=False,
241
+ upcast_attention=False,
242
+ use_sc_attn=False,
243
+ use_st_attn=False,
244
+ ):
245
+ super().__init__()
246
+ resnets = []
247
+ attentions = []
248
+
249
+ self.has_cross_attention = True
250
+ self.attn_num_head_channels = attn_num_head_channels
251
+
252
+ for i in range(num_layers):
253
+ in_channels = in_channels if i == 0 else out_channels
254
+ resnets.append(
255
+ ResnetBlock2D(
256
+ in_channels=in_channels,
257
+ out_channels=out_channels,
258
+ temb_channels=temb_channels,
259
+ eps=resnet_eps,
260
+ groups=resnet_groups,
261
+ dropout=dropout,
262
+ time_embedding_norm=resnet_time_scale_shift,
263
+ non_linearity=resnet_act_fn,
264
+ output_scale_factor=output_scale_factor,
265
+ pre_norm=resnet_pre_norm,
266
+ )
267
+ )
268
+ if dual_cross_attention:
269
+ raise NotImplementedError
270
+ attentions.append(
271
+ Transformer2DModel(
272
+ attn_num_head_channels,
273
+ out_channels // attn_num_head_channels,
274
+ in_channels=out_channels,
275
+ num_layers=1,
276
+ cross_attention_dim=cross_attention_dim,
277
+ norm_num_groups=resnet_groups,
278
+ use_linear_projection=use_linear_projection,
279
+ only_cross_attention=only_cross_attention,
280
+ upcast_attention=upcast_attention,
281
+ use_sc_attn=use_sc_attn,
282
+ use_st_attn=True if (use_st_attn and i == 0) else False,
283
+ )
284
+ )
285
+ self.attentions = nn.ModuleList(attentions)
286
+ self.resnets = nn.ModuleList(resnets)
287
+
288
+ if add_downsample:
289
+ self.downsamplers = nn.ModuleList(
290
+ [
291
+ Downsample2D(
292
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
293
+ )
294
+ ]
295
+ )
296
+ else:
297
+ self.downsamplers = None
298
+
299
+ self.gradient_checkpointing = False
300
+
301
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, normal_infer=False):
302
+ output_states = ()
303
+
304
+ for resnet, attn in zip(self.resnets, self.attentions):
305
+ if self.training and self.gradient_checkpointing:
306
+
307
+ def create_custom_forward(module, return_dict=None, normal_infer=False):
308
+ def custom_forward(*inputs):
309
+ if return_dict is not None:
310
+ return module(*inputs, return_dict=return_dict, normal_infer=normal_infer)
311
+ else:
312
+ return module(*inputs)
313
+
314
+ return custom_forward
315
+
316
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
317
+ hidden_states = torch.utils.checkpoint.checkpoint(
318
+ create_custom_forward(attn, return_dict=False, normal_infer=normal_infer),
319
+ hidden_states,
320
+ encoder_hidden_states,
321
+ )[0]
322
+ else:
323
+ hidden_states = resnet(hidden_states, temb)
324
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, normal_infer=normal_infer).sample
325
+
326
+ output_states += (hidden_states,)
327
+
328
+ if self.downsamplers is not None:
329
+ for downsampler in self.downsamplers:
330
+ hidden_states = downsampler(hidden_states)
331
+
332
+ output_states += (hidden_states,)
333
+
334
+ return hidden_states, output_states
335
+
336
+
337
+ class DownBlock2D(nn.Module):
338
+ def __init__(
339
+ self,
340
+ in_channels: int,
341
+ out_channels: int,
342
+ temb_channels: int,
343
+ dropout: float = 0.0,
344
+ num_layers: int = 1,
345
+ resnet_eps: float = 1e-6,
346
+ resnet_time_scale_shift: str = "default",
347
+ resnet_act_fn: str = "swish",
348
+ resnet_groups: int = 32,
349
+ resnet_pre_norm: bool = True,
350
+ output_scale_factor=1.0,
351
+ add_downsample=True,
352
+ downsample_padding=1,
353
+ ):
354
+ super().__init__()
355
+ resnets = []
356
+
357
+ for i in range(num_layers):
358
+ in_channels = in_channels if i == 0 else out_channels
359
+ resnets.append(
360
+ ResnetBlock2D(
361
+ in_channels=in_channels,
362
+ out_channels=out_channels,
363
+ temb_channels=temb_channels,
364
+ eps=resnet_eps,
365
+ groups=resnet_groups,
366
+ dropout=dropout,
367
+ time_embedding_norm=resnet_time_scale_shift,
368
+ non_linearity=resnet_act_fn,
369
+ output_scale_factor=output_scale_factor,
370
+ pre_norm=resnet_pre_norm,
371
+ )
372
+ )
373
+
374
+ self.resnets = nn.ModuleList(resnets)
375
+
376
+ if add_downsample:
377
+ self.downsamplers = nn.ModuleList(
378
+ [
379
+ Downsample2D(
380
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
381
+ )
382
+ ]
383
+ )
384
+ else:
385
+ self.downsamplers = None
386
+
387
+ self.gradient_checkpointing = False
388
+
389
+ def forward(self, hidden_states, temb=None):
390
+ output_states = ()
391
+
392
+ for resnet in self.resnets:
393
+ if self.training and self.gradient_checkpointing:
394
+
395
+ def create_custom_forward(module):
396
+ def custom_forward(*inputs):
397
+ return module(*inputs)
398
+
399
+ return custom_forward
400
+
401
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
402
+ else:
403
+ hidden_states = resnet(hidden_states, temb)
404
+
405
+ output_states += (hidden_states,)
406
+
407
+ if self.downsamplers is not None:
408
+ for downsampler in self.downsamplers:
409
+ hidden_states = downsampler(hidden_states)
410
+
411
+ output_states += (hidden_states,)
412
+
413
+ return hidden_states, output_states
414
+
415
+
416
+ class CrossAttnUpBlock2D(nn.Module):
417
+ def __init__(
418
+ self,
419
+ in_channels: int,
420
+ out_channels: int,
421
+ prev_output_channel: int,
422
+ temb_channels: int,
423
+ dropout: float = 0.0,
424
+ num_layers: int = 1,
425
+ resnet_eps: float = 1e-6,
426
+ resnet_time_scale_shift: str = "default",
427
+ resnet_act_fn: str = "swish",
428
+ resnet_groups: int = 32,
429
+ resnet_pre_norm: bool = True,
430
+ attn_num_head_channels=1,
431
+ cross_attention_dim=1280,
432
+ output_scale_factor=1.0,
433
+ add_upsample=True,
434
+ dual_cross_attention=False,
435
+ use_linear_projection=False,
436
+ only_cross_attention=False,
437
+ upcast_attention=False,
438
+ use_sc_attn=False,
439
+ use_st_attn=False,
440
+ ):
441
+ super().__init__()
442
+ resnets = []
443
+ attentions = []
444
+
445
+ self.has_cross_attention = True
446
+ self.attn_num_head_channels = attn_num_head_channels
447
+
448
+ for i in range(num_layers):
449
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
450
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
451
+
452
+ resnets.append(
453
+ ResnetBlock2D(
454
+ in_channels=resnet_in_channels + res_skip_channels,
455
+ out_channels=out_channels,
456
+ temb_channels=temb_channels,
457
+ eps=resnet_eps,
458
+ groups=resnet_groups,
459
+ dropout=dropout,
460
+ time_embedding_norm=resnet_time_scale_shift,
461
+ non_linearity=resnet_act_fn,
462
+ output_scale_factor=output_scale_factor,
463
+ pre_norm=resnet_pre_norm,
464
+ )
465
+ )
466
+ if dual_cross_attention:
467
+ raise NotImplementedError
468
+ attentions.append(
469
+ Transformer2DModel(
470
+ attn_num_head_channels,
471
+ out_channels // attn_num_head_channels,
472
+ in_channels=out_channels,
473
+ num_layers=1,
474
+ cross_attention_dim=cross_attention_dim,
475
+ norm_num_groups=resnet_groups,
476
+ use_linear_projection=use_linear_projection,
477
+ only_cross_attention=only_cross_attention,
478
+ upcast_attention=upcast_attention,
479
+ use_sc_attn=use_sc_attn,
480
+ use_st_attn=True if (use_st_attn and i == 0) else False,
481
+ )
482
+ )
483
+
484
+ self.attentions = nn.ModuleList(attentions)
485
+ self.resnets = nn.ModuleList(resnets)
486
+
487
+ if add_upsample:
488
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
489
+ else:
490
+ self.upsamplers = None
491
+
492
+ self.gradient_checkpointing = False
493
+
494
+ def forward(
495
+ self,
496
+ hidden_states,
497
+ res_hidden_states_tuple,
498
+ temb=None,
499
+ encoder_hidden_states=None,
500
+ upsample_size=None,
501
+ attention_mask=None,
502
+ normal_infer=False,
503
+ ):
504
+ for resnet, attn in zip(self.resnets, self.attentions):
505
+ # pop res hidden states
506
+ res_hidden_states = res_hidden_states_tuple[-1]
507
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
508
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
509
+
510
+ if self.training and self.gradient_checkpointing:
511
+
512
+ def create_custom_forward(module, return_dict=None, normal_infer=False):
513
+ def custom_forward(*inputs):
514
+ if return_dict is not None:
515
+ return module(*inputs, return_dict=return_dict, normal_infer=normal_infer)
516
+ else:
517
+ return module(*inputs)
518
+
519
+ return custom_forward
520
+
521
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
522
+ hidden_states = torch.utils.checkpoint.checkpoint(
523
+ create_custom_forward(attn, return_dict=False, normal_infer=normal_infer),
524
+ hidden_states,
525
+ encoder_hidden_states,
526
+ )[0]
527
+ else:
528
+ hidden_states = resnet(hidden_states, temb)
529
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, normal_infer=normal_infer).sample
530
+
531
+ if self.upsamplers is not None:
532
+ for upsampler in self.upsamplers:
533
+ hidden_states = upsampler(hidden_states, upsample_size)
534
+
535
+ return hidden_states
536
+
537
+
538
+ class UpBlock2D(nn.Module):
539
+ def __init__(
540
+ self,
541
+ in_channels: int,
542
+ prev_output_channel: int,
543
+ out_channels: int,
544
+ temb_channels: int,
545
+ dropout: float = 0.0,
546
+ num_layers: int = 1,
547
+ resnet_eps: float = 1e-6,
548
+ resnet_time_scale_shift: str = "default",
549
+ resnet_act_fn: str = "swish",
550
+ resnet_groups: int = 32,
551
+ resnet_pre_norm: bool = True,
552
+ output_scale_factor=1.0,
553
+ add_upsample=True,
554
+ ):
555
+ super().__init__()
556
+ resnets = []
557
+
558
+ for i in range(num_layers):
559
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
560
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
561
+
562
+ resnets.append(
563
+ ResnetBlock2D(
564
+ in_channels=resnet_in_channels + res_skip_channels,
565
+ out_channels=out_channels,
566
+ temb_channels=temb_channels,
567
+ eps=resnet_eps,
568
+ groups=resnet_groups,
569
+ dropout=dropout,
570
+ time_embedding_norm=resnet_time_scale_shift,
571
+ non_linearity=resnet_act_fn,
572
+ output_scale_factor=output_scale_factor,
573
+ pre_norm=resnet_pre_norm,
574
+ )
575
+ )
576
+
577
+ self.resnets = nn.ModuleList(resnets)
578
+
579
+ if add_upsample:
580
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
581
+ else:
582
+ self.upsamplers = None
583
+
584
+ self.gradient_checkpointing = False
585
+
586
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
587
+ for resnet in self.resnets:
588
+ # pop res hidden states
589
+ res_hidden_states = res_hidden_states_tuple[-1]
590
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
591
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
592
+
593
+ if self.training and self.gradient_checkpointing:
594
+
595
+ def create_custom_forward(module):
596
+ def custom_forward(*inputs):
597
+ return module(*inputs)
598
+
599
+ return custom_forward
600
+
601
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
602
+ else:
603
+ hidden_states = resnet(hidden_states, temb)
604
+
605
+ if self.upsamplers is not None:
606
+ for upsampler in self.upsamplers:
607
+ hidden_states = upsampler(hidden_states, upsample_size)
608
+
609
+ return hidden_states
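
Note (not part of the committed diff): the blocks in this file mirror the diffusers UNet blocks but thread the extra use_sc_attn / use_st_attn flags down to the Transformer2DModel layers, enabling spatial-temporal attention only for the first transformer of a block. A construction sketch with example hyperparameters (values are illustrative, not necessarily those used by the pipeline):

    from vid2vid_zero.models.unet_2d_blocks import get_down_block

    block = get_down_block(
        "CrossAttnDownBlock2D",
        num_layers=2,
        in_channels=320,
        out_channels=320,
        temb_channels=1280,
        add_downsample=True,
        resnet_eps=1e-5,
        resnet_act_fn="silu",
        attn_num_head_channels=8,
        resnet_groups=32,
        cross_attention_dim=768,
        downsample_padding=1,
        use_sc_attn=True,   # swap self-attention for SparseCausalAttention
        use_st_attn=True,   # first transformer in the block uses SpatialTemporalAttention
    )
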
vid2vid_zero/models/unet_2d_condition.py ADDED
@@ -0,0 +1,712 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2
+
3
+ import os, json
4
+ from dataclasses import dataclass
5
+ from typing import List, Optional, Tuple, Union
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.utils.checkpoint
10
+
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers.modeling_utils import ModelMixin
13
+ from diffusers.utils import BaseOutput, logging
14
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
15
+ from .unet_2d_blocks import (
16
+ CrossAttnDownBlock2D,
17
+ CrossAttnUpBlock2D,
18
+ DownBlock2D,
19
+ UNetMidBlock2DCrossAttn,
20
+ # UNetMidBlock2DSimpleCrossAttn,
21
+ UpBlock2D,
22
+ get_down_block,
23
+ get_up_block,
24
+ )
25
+ from .resnet_2d import InflatedConv3d
26
+
27
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
+
29
+
30
+ @dataclass
31
+ class UNet2DConditionOutput(BaseOutput):
32
+ """
33
+ Args:
34
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
35
+ Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
36
+ """
37
+
38
+ sample: torch.FloatTensor
39
+
40
+
41
+ class UNet2DConditionModel(ModelMixin, ConfigMixin):
42
+ r"""
43
+ UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
44
+ and returns sample shaped output.
45
+
46
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
47
+ implements for all the models (such as downloading or saving, etc.)
48
+
49
+ Parameters:
50
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
51
+ Height and width of input/output sample.
52
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
53
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
54
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
55
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
56
+ Whether to flip the sin to cos in the time embedding.
57
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
58
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
59
+ The tuple of downsample blocks to use.
60
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
61
+ The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`.
62
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
63
+ The tuple of upsample blocks to use.
64
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
65
+ The tuple of output channels for each block.
66
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
67
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
68
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
69
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
70
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
71
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
72
+ cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
73
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
74
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
75
+ for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
76
+ class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
77
+ summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`.
78
+ """
79
+
80
+ _supports_gradient_checkpointing = True
81
+
82
+ @register_to_config
83
+ def __init__(
84
+ self,
85
+ sample_size: Optional[int] = None,
86
+ in_channels: int = 4,
87
+ out_channels: int = 4,
88
+ center_input_sample: bool = False,
89
+ flip_sin_to_cos: bool = True,
90
+ freq_shift: int = 0,
91
+ down_block_types: Tuple[str] = (
92
+ "CrossAttnDownBlock2D",
93
+ "CrossAttnDownBlock2D",
94
+ "CrossAttnDownBlock2D",
95
+ "DownBlock2D",
96
+ ),
97
+ mid_block_type: str = "UNetMidBlock2DCrossAttn",
98
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
99
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
100
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
101
+ layers_per_block: int = 2,
102
+ downsample_padding: int = 1,
103
+ mid_block_scale_factor: float = 1,
104
+ act_fn: str = "silu",
105
+ norm_num_groups: int = 32,
106
+ norm_eps: float = 1e-5,
107
+ cross_attention_dim: int = 1280,
108
+ attention_head_dim: Union[int, Tuple[int]] = 8,
109
+ dual_cross_attention: bool = False,
110
+ use_linear_projection: bool = False,
111
+ class_embed_type: Optional[str] = None,
112
+ num_class_embeds: Optional[int] = None,
113
+ upcast_attention: bool = False,
114
+ resnet_time_scale_shift: str = "default",
115
+ use_sc_attn: bool = False,
116
+ use_st_attn: bool = False,
117
+ st_attn_idx: int = None,
118
+ ):
119
+ super().__init__()
120
+
121
+ self.sample_size = sample_size
122
+ time_embed_dim = block_out_channels[0] * 4
123
+
124
+ # input
125
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
126
+
127
+ # time
128
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
129
+ timestep_input_dim = block_out_channels[0]
130
+
131
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
132
+
133
+ # class embedding
134
+ if class_embed_type is None and num_class_embeds is not None:
135
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
136
+ elif class_embed_type == "timestep":
137
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
138
+ elif class_embed_type == "identity":
139
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
140
+ else:
141
+ self.class_embedding = None
142
+
143
+ self.down_blocks = nn.ModuleList([])
144
+ self.mid_block = None
145
+ self.up_blocks = nn.ModuleList([])
146
+
147
+ if isinstance(only_cross_attention, bool):
148
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
149
+
150
+ if isinstance(attention_head_dim, int):
151
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
152
+
153
+ # down
154
+ output_channel = block_out_channels[0]
155
+ for i, down_block_type in enumerate(down_block_types):
156
+ input_channel = output_channel
157
+ output_channel = block_out_channels[i]
158
+ is_final_block = i == len(block_out_channels) - 1
159
+
160
+ down_block = get_down_block(
161
+ down_block_type,
162
+ num_layers=layers_per_block,
163
+ in_channels=input_channel,
164
+ out_channels=output_channel,
165
+ temb_channels=time_embed_dim,
166
+ add_downsample=not is_final_block,
167
+ resnet_eps=norm_eps,
168
+ resnet_act_fn=act_fn,
169
+ resnet_groups=norm_num_groups,
170
+ cross_attention_dim=cross_attention_dim,
171
+ attn_num_head_channels=attention_head_dim[i],
172
+ downsample_padding=downsample_padding,
173
+ dual_cross_attention=dual_cross_attention,
174
+ use_linear_projection=use_linear_projection,
175
+ only_cross_attention=only_cross_attention[i],
176
+ upcast_attention=upcast_attention,
177
+ resnet_time_scale_shift=resnet_time_scale_shift,
178
+ use_sc_attn=use_sc_attn,
179
+ # idx range from 0 to 2, i.e., ['CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D']
180
+ use_st_attn=True if (use_st_attn and i == st_attn_idx) else False,
181
+ )
182
+ self.down_blocks.append(down_block)
183
+
184
+ # mid
185
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
186
+ self.mid_block = UNetMidBlock2DCrossAttn(
187
+ in_channels=block_out_channels[-1],
188
+ temb_channels=time_embed_dim,
189
+ resnet_eps=norm_eps,
190
+ resnet_act_fn=act_fn,
191
+ output_scale_factor=mid_block_scale_factor,
192
+ resnet_time_scale_shift=resnet_time_scale_shift,
193
+ cross_attention_dim=cross_attention_dim,
194
+ attn_num_head_channels=attention_head_dim[-1],
195
+ resnet_groups=norm_num_groups,
196
+ dual_cross_attention=dual_cross_attention,
197
+ use_linear_projection=use_linear_projection,
198
+ upcast_attention=upcast_attention,
199
+ use_sc_attn=use_sc_attn,
200
+ use_st_attn=use_st_attn,
201
+ )
202
+ else:
203
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
204
+
205
+ # count how many layers upsample the videos
206
+ self.num_upsamplers = 0
207
+
208
+ # up
209
+ reversed_block_out_channels = list(reversed(block_out_channels))
210
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
211
+ only_cross_attention = list(reversed(only_cross_attention))
212
+ output_channel = reversed_block_out_channels[0]
213
+ for i, up_block_type in enumerate(up_block_types):
214
+ is_final_block = i == len(block_out_channels) - 1
215
+
216
+ prev_output_channel = output_channel
217
+ output_channel = reversed_block_out_channels[i]
218
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
219
+
220
+ # add upsample block for all BUT final layer
221
+ if not is_final_block:
222
+ add_upsample = True
223
+ self.num_upsamplers += 1
224
+ else:
225
+ add_upsample = False
226
+
227
+ up_block = get_up_block(
228
+ up_block_type,
229
+ num_layers=layers_per_block + 1,
230
+ in_channels=input_channel,
231
+ out_channels=output_channel,
232
+ prev_output_channel=prev_output_channel,
233
+ temb_channels=time_embed_dim,
234
+ add_upsample=add_upsample,
235
+ resnet_eps=norm_eps,
236
+ resnet_act_fn=act_fn,
237
+ resnet_groups=norm_num_groups,
238
+ cross_attention_dim=cross_attention_dim,
239
+ attn_num_head_channels=reversed_attention_head_dim[i],
240
+ dual_cross_attention=dual_cross_attention,
241
+ use_linear_projection=use_linear_projection,
242
+ only_cross_attention=only_cross_attention[i],
243
+ upcast_attention=upcast_attention,
244
+ resnet_time_scale_shift=resnet_time_scale_shift,
245
+ use_sc_attn=use_sc_attn,
246
+ # idx range from 0 to 2, i.e., ['UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D']
247
+ use_st_attn=True if (use_st_attn and i-1 == st_attn_idx) else False,
248
+ )
249
+ self.up_blocks.append(up_block)
250
+ prev_output_channel = output_channel
251
+
252
+ # out
253
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
254
+ self.conv_act = nn.SiLU()
255
+ self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
256
+
257
+ def set_attention_slice(self, slice_size):
258
+ r"""
259
+ Enable sliced attention computation.
260
+
261
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
262
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
263
+
264
+ Args:
265
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
266
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
267
+ `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
268
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
269
+ must be a multiple of `slice_size`.
270
+ """
271
+ sliceable_head_dims = []
272
+
273
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
274
+ if hasattr(module, "set_attention_slice"):
275
+ sliceable_head_dims.append(module.sliceable_head_dim)
276
+
277
+ for child in module.children():
278
+ fn_recursive_retrieve_slicable_dims(child)
279
+
280
+ # retrieve number of attention layers
281
+ for module in self.children():
282
+ fn_recursive_retrieve_slicable_dims(module)
283
+
284
+ num_slicable_layers = len(sliceable_head_dims)
285
+
286
+ if slice_size == "auto":
287
+ # half the attention head size is usually a good trade-off between
288
+ # speed and memory
289
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
290
+ elif slice_size == "max":
291
+ # make smallest slice possible
292
+ slice_size = num_slicable_layers * [1]
293
+
294
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
295
+
296
+ if len(slice_size) != len(sliceable_head_dims):
297
+ raise ValueError(
298
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
299
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
300
+ )
301
+
302
+ for i in range(len(slice_size)):
303
+ size = slice_size[i]
304
+ dim = sliceable_head_dims[i]
305
+ if size is not None and size > dim:
306
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
307
+
308
+ # Recursively walk through all the children.
309
+ # Any children which exposes the set_attention_slice method
310
+ # gets the message
311
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
312
+ if hasattr(module, "set_attention_slice"):
313
+ module.set_attention_slice(slice_size.pop())
314
+
315
+ for child in module.children():
316
+ fn_recursive_set_attention_slice(child, slice_size)
317
+
318
+ reversed_slice_size = list(reversed(slice_size))
319
+ for module in self.children():
320
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
321
+
322
+ def _set_gradient_checkpointing(self, module, value=False):
323
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
324
+ module.gradient_checkpointing = value
325
+
326
+ def forward(
327
+ self,
328
+ sample: torch.FloatTensor,
329
+ timestep: Union[torch.Tensor, float, int],
330
+ encoder_hidden_states: torch.Tensor,
331
+ class_labels: Optional[torch.Tensor] = None,
332
+ attention_mask: Optional[torch.Tensor] = None,
333
+ return_dict: bool = True,
334
+ normal_infer: bool = False,
335
+ ) -> Union[UNet2DConditionOutput, Tuple]:
336
+ r"""
337
+ Args:
338
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
339
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
340
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
341
+ return_dict (`bool`, *optional*, defaults to `True`):
342
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
343
+
344
+ Returns:
345
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
346
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
347
+ returning a tuple, the first element is the sample tensor.
348
+ """
349
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
350
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
351
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
352
+ # on the fly if necessary.
353
+ default_overall_up_factor = 2**self.num_upsamplers
354
+
355
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
356
+ forward_upsample_size = False
357
+ upsample_size = None
358
+
359
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
360
+ logger.info("Forward upsample size to force interpolation output size.")
361
+ forward_upsample_size = True
362
+
363
+ # prepare attention_mask
364
+ if attention_mask is not None:
365
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
366
+ attention_mask = attention_mask.unsqueeze(1)
367
+
368
+ # center input if necessary
369
+ if self.config.center_input_sample:
370
+ sample = 2 * sample - 1.0
371
+
372
+ # time
373
+ timesteps = timestep
374
+ if not torch.is_tensor(timesteps):
375
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
376
+ # This would be a good case for the `match` statement (Python 3.10+)
377
+ is_mps = sample.device.type == "mps"
378
+ if isinstance(timestep, float):
379
+ dtype = torch.float32 if is_mps else torch.float64
380
+ else:
381
+ dtype = torch.int32 if is_mps else torch.int64
382
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
383
+ elif len(timesteps.shape) == 0:
384
+ timesteps = timesteps[None].to(sample.device)
385
+
386
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
387
+ timesteps = timesteps.expand(sample.shape[0])
388
+
389
+ t_emb = self.time_proj(timesteps)
390
+
391
+ # timesteps does not contain any weights and will always return f32 tensors
392
+ # but time_embedding might actually be running in fp16. so we need to cast here.
393
+ # there might be better ways to encapsulate this.
394
+ t_emb = t_emb.to(dtype=self.dtype)
395
+ emb = self.time_embedding(t_emb)
396
+
397
+ if self.class_embedding is not None:
398
+ if class_labels is None:
399
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
400
+
401
+ if self.config.class_embed_type == "timestep":
402
+ class_labels = self.time_proj(class_labels)
403
+
404
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
405
+ emb = emb + class_emb
406
+
407
+ # pre-process
408
+ sample = self.conv_in(sample)
409
+
410
+ # down
411
+ down_block_res_samples = (sample,)
412
+ for downsample_block in self.down_blocks:
413
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
414
+ sample, res_samples = downsample_block(
415
+ hidden_states=sample,
416
+ temb=emb,
417
+ encoder_hidden_states=encoder_hidden_states,
418
+ attention_mask=attention_mask,
419
+ normal_infer=normal_infer,
420
+ )
421
+ else:
422
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
423
+
424
+ down_block_res_samples += res_samples
425
+
426
+ # mid
427
+ sample = self.mid_block(
428
+ sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask,
429
+ normal_infer=normal_infer,
430
+ )
431
+
432
+ # up
433
+ for i, upsample_block in enumerate(self.up_blocks):
434
+ is_final_block = i == len(self.up_blocks) - 1
435
+
436
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
437
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
438
+
439
+ # if we have not reached the final block and need to forward the
440
+ # upsample size, we do it here
441
+ if not is_final_block and forward_upsample_size:
442
+ upsample_size = down_block_res_samples[-1].shape[2:]
443
+
444
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
445
+ sample = upsample_block(
446
+ hidden_states=sample,
447
+ temb=emb,
448
+ res_hidden_states_tuple=res_samples,
449
+ encoder_hidden_states=encoder_hidden_states,
450
+ upsample_size=upsample_size,
451
+ attention_mask=attention_mask,
452
+ normal_infer=normal_infer,
453
+ )
454
+ else:
455
+ sample = upsample_block(
456
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
457
+ )
458
+ # post-process
459
+ sample = self.conv_norm_out(sample)
460
+ sample = self.conv_act(sample)
461
+ sample = self.conv_out(sample)
462
+
463
+ if not return_dict:
464
+ return (sample,)
465
+
466
+ return UNet2DConditionOutput(sample=sample)
467
+
468
+ @classmethod
469
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
470
+ r"""
471
+ for gradio demo
472
+ """
473
+
474
+ import diffusers
475
+ __version__ = diffusers.__version__
476
+ from diffusers.utils import (
477
+ CONFIG_NAME,
478
+ DIFFUSERS_CACHE,
479
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT,
480
+ SAFETENSORS_WEIGHTS_NAME,
481
+ WEIGHTS_NAME,
482
+ is_accelerate_available,
483
+ is_safetensors_available,
484
+ is_torch_version,
485
+ logging,
486
+ )
487
+
488
+ if is_torch_version(">=", "1.9.0"):
489
+ _LOW_CPU_MEM_USAGE_DEFAULT = True
490
+ else:
491
+ _LOW_CPU_MEM_USAGE_DEFAULT = False
492
+
493
+
494
+ if is_accelerate_available():
495
+ import accelerate
496
+ from accelerate.utils import set_module_tensor_to_device
497
+ from accelerate.utils.versions import is_torch_version
498
+
499
+ if is_safetensors_available():
500
+ import safetensors
501
+
502
+ from diffusers.modeling_utils import load_state_dict
503
+
504
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
505
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
506
+ force_download = kwargs.pop("force_download", False)
507
+ resume_download = kwargs.pop("resume_download", False)
508
+ proxies = kwargs.pop("proxies", None)
509
+ output_loading_info = kwargs.pop("output_loading_info", False)
510
+ local_files_only = kwargs.pop("local_files_only", False)
511
+ use_auth_token = kwargs.pop("use_auth_token", None)
512
+ revision = kwargs.pop("revision", None)
513
+ torch_dtype = kwargs.pop("torch_dtype", None)
514
+ subfolder = kwargs.pop("subfolder", None)
515
+ device_map = kwargs.pop("device_map", None)
516
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
517
+ # custom arg
518
+ use_sc_attn = kwargs.pop("use_sc_attn", True)
519
+ use_st_attn = kwargs.pop("use_st_attn", True)
520
+ st_attn_idx = kwargs.pop("st_attn_idx", 0)
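+ # Note: use_sc_attn, use_st_attn and st_attn_idx are options specific to this repo's modified UNet
+ # (they presumably toggle the customized attention modules); they are injected into the loaded
+ # config further below instead of being forwarded to the stock diffusers loader.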
521
+
522
+ if low_cpu_mem_usage and not is_accelerate_available():
523
+ low_cpu_mem_usage = False
524
+ logger.warning(
525
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
526
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
527
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
528
+ " install accelerate\n```\n."
529
+ )
530
+
531
+ if device_map is not None and not is_accelerate_available():
532
+ raise NotImplementedError(
533
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
534
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
535
+ )
536
+
537
+ # Check if we can handle device_map and dispatching the weights
538
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
539
+ raise NotImplementedError(
540
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
541
+ " `device_map=None`."
542
+ )
543
+
544
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
545
+ raise NotImplementedError(
546
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
547
+ " `low_cpu_mem_usage=False`."
548
+ )
549
+
550
+ if low_cpu_mem_usage is False and device_map is not None:
551
+ raise ValueError(
552
+ f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
553
+ " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
554
+ )
555
+
556
+ user_agent = {
557
+ "diffusers": __version__,
558
+ "file_type": "model",
559
+ "framework": "pytorch",
560
+ }
561
+
562
+ # Load config if we don't provide a configuration
563
+ config_path = pretrained_model_name_or_path
564
+
565
+ # No sharded-checkpoint handling here; the model weights are loaded from a single archive file below.
566
+ # Load model
567
+
568
+ model_file = None
569
+ if is_safetensors_available():
570
+ try:
571
+ model_file = cls._get_model_file(
572
+ pretrained_model_name_or_path,
573
+ weights_name=SAFETENSORS_WEIGHTS_NAME,
574
+ cache_dir=cache_dir,
575
+ force_download=force_download,
576
+ resume_download=resume_download,
577
+ proxies=proxies,
578
+ local_files_only=local_files_only,
579
+ use_auth_token=use_auth_token,
580
+ revision=revision,
581
+ subfolder=subfolder,
582
+ user_agent=user_agent,
583
+ )
584
+ except Exception:  # fall back to the standard .bin weights below if safetensors loading fails
585
+ pass
586
+ if model_file is None:
587
+ model_file = cls._get_model_file(
588
+ pretrained_model_name_or_path,
589
+ weights_name=WEIGHTS_NAME,
590
+ cache_dir=cache_dir,
591
+ force_download=force_download,
592
+ resume_download=resume_download,
593
+ proxies=proxies,
594
+ local_files_only=local_files_only,
595
+ use_auth_token=use_auth_token,
596
+ revision=revision,
597
+ subfolder=subfolder,
598
+ user_agent=user_agent,
599
+ )
600
+
601
+ if low_cpu_mem_usage:
602
+ # Instantiate model with empty weights
603
+ with accelerate.init_empty_weights():
604
+ config, unused_kwargs = cls.load_config(
605
+ config_path,
606
+ cache_dir=cache_dir,
607
+ return_unused_kwargs=True,
608
+ force_download=force_download,
609
+ resume_download=resume_download,
610
+ proxies=proxies,
611
+ local_files_only=local_files_only,
612
+ use_auth_token=use_auth_token,
613
+ revision=revision,
614
+ subfolder=subfolder,
615
+ device_map=device_map,
616
+ **kwargs,
617
+ )
618
+
619
+ # custom arg
620
+ config['use_sc_attn'] = use_sc_attn
621
+ config['use_st_attn'] = use_st_attn
622
+ config['st_attn_idx'] = st_attn_idx
623
+
624
+ model = cls.from_config(config, **unused_kwargs)
625
+
626
+ # if device_map is None, load the state dict and move the params from the meta device to the cpu
627
+ if device_map is None:
628
+ param_device = "cpu"
629
+ state_dict = load_state_dict(model_file)
630
+ # move the params from the meta device to the cpu
631
+ for param_name, param in state_dict.items():
632
+ set_module_tensor_to_device(model, param_name, param_device, value=param)
633
+ else: # else let accelerate handle loading and dispatching.
634
+ # Load weights and dispatch according to the device_map
635
+ # by default the device_map is None and the weights are loaded on the CPU
636
+ accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)
637
+
638
+ loading_info = {
639
+ "missing_keys": [],
640
+ "unexpected_keys": [],
641
+ "mismatched_keys": [],
642
+ "error_msgs": [],
643
+ }
644
+ else:
645
+ config, unused_kwargs = cls.load_config(
646
+ config_path,
647
+ cache_dir=cache_dir,
648
+ return_unused_kwargs=True,
649
+ force_download=force_download,
650
+ resume_download=resume_download,
651
+ proxies=proxies,
652
+ local_files_only=local_files_only,
653
+ use_auth_token=use_auth_token,
654
+ revision=revision,
655
+ subfolder=subfolder,
656
+ device_map=device_map,
657
+ **kwargs,
658
+ )
659
+
660
+ # custom arg
661
+ config['use_sc_attn'] = use_sc_attn
662
+ config['use_st_attn'] = use_st_attn
663
+ config['st_attn_idx'] = st_attn_idx
664
+
665
+ model = cls.from_config(config, **unused_kwargs)
666
+
667
+ state_dict = load_state_dict(model_file)
668
+ dtype = set(v.dtype for v in state_dict.values())
669
+
670
+ if len(dtype) > 1 and torch.float32 not in dtype:
671
+ raise ValueError(
672
+ f"The weights of the model file {model_file} have a mixture of incompatible dtypes {dtype}. Please"
673
+ f" make sure that {model_file} weights have only one dtype."
674
+ )
675
+ elif len(dtype) > 1 and torch.float32 in dtype:
676
+ dtype = torch.float32
677
+ else:
678
+ dtype = dtype.pop()
679
+
680
+ # move model to correct dtype
681
+ model = model.to(dtype)
682
+
683
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
684
+ model,
685
+ state_dict,
686
+ model_file,
687
+ pretrained_model_name_or_path,
688
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
689
+ )
690
+
691
+ loading_info = {
692
+ "missing_keys": missing_keys,
693
+ "unexpected_keys": unexpected_keys,
694
+ "mismatched_keys": mismatched_keys,
695
+ "error_msgs": error_msgs,
696
+ }
697
+
698
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
699
+ raise ValueError(
700
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
701
+ )
702
+ elif torch_dtype is not None:
703
+ model = model.to(torch_dtype)
704
+
705
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
706
+
707
+ # Set model in evaluation mode to deactivate DropOut modules by default
708
+ model.eval()
709
+ if output_loading_info:
710
+ return model, loading_info
711
+
712
+ return model
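+ # Illustrative usage only (not part of this commit; the class name is a placeholder and the checkpoint
+ # path is an assumption):
+ # >>> unet = ModifiedUNet.from_pretrained(
+ # ...     "checkpoints/stable-diffusion-v1-4", subfolder="unet",
+ # ...     use_sc_attn=True, use_st_attn=True, st_attn_idx=0)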
vid2vid_zero/p2p/null_text_w_ptp.py ADDED
@@ -0,0 +1,504 @@
1
+ # Copyright 2022 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Optional, Union, Tuple, List, Callable, Dict
17
+ from tqdm import tqdm
18
+ import torch
19
+ import torch.nn.functional as nnf
20
+ import numpy as np
21
+ import abc
22
+ from . import ptp_utils
23
+ from . import seq_aligner
24
+ import shutil
25
+ from torch.optim.adam import Adam
26
+ from PIL import Image
27
+
28
+
29
+ LOW_RESOURCE = False
30
+ NUM_DDIM_STEPS = 50
31
+ MAX_NUM_WORDS = 77
32
+ device = torch.device('cuda')
33
+ from transformers import CLIPTextModel, CLIPTokenizer
34
+
35
+ pretrained_model_path = "checkpoints/stable-diffusion-v1-4/"
36
+
37
+ ldm_stable = None
38
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
39
+
40
+
41
+ class LocalBlend:
42
+
43
+ def get_mask(self, x_t, maps, alpha, use_pool):  # x_t provides the target spatial size for the mask
44
+ k = 1
45
+ maps = (maps * alpha).sum(-1).mean(1)
46
+ if use_pool:
47
+ maps = nnf.max_pool2d(maps, (k * 2 + 1, k * 2 +1), (1, 1), padding=(k, k))
48
+ mask = nnf.interpolate(maps, size=(x_t.shape[2:]))
49
+ mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
50
+ mask = mask.gt(self.th[1-int(use_pool)])
51
+ mask = mask[:1] + mask
52
+ return mask
53
+
54
+ def __call__(self, x_t, attention_store):
55
+ self.counter += 1
56
+ if self.counter > self.start_blend:
57
+
58
+ maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3]
59
+ maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, MAX_NUM_WORDS) for item in maps]
60
+ maps = torch.cat(maps, dim=1)
61
+ mask = self.get_mask(x_t, maps, self.alpha_layers, True)
62
+ if self.substruct_layers is not None:
63
+ maps_sub = ~self.get_mask(x_t, maps, self.substruct_layers, False)
64
+ mask = mask * maps_sub
65
+ mask = mask.float()
66
+ x_t = x_t[:1] + mask * (x_t - x_t[:1])
67
+ return x_t
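+ # LocalBlend: after the first `start_blend` denoising steps, edits are confined to the regions where the
+ # selected words attend; outside that mask the source latent x_t[:1] is copied over the edited branches.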
68
+
69
+ def __init__(self, prompts: List[str], words: List[List[str]], substruct_words=None, start_blend=0.2, th=(.3, .3)):
70
+ alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS)
71
+ for i, (prompt, words_) in enumerate(zip(prompts, words)):
72
+ if type(words_) is str:
73
+ words_ = [words_]
74
+ for word in words_:
75
+ ind = ptp_utils.get_word_inds(prompt, word, tokenizer)
76
+ alpha_layers[i, :, :, :, :, ind] = 1
77
+
78
+ if substruct_words is not None:
79
+ substruct_layers = torch.zeros(len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS)
80
+ for i, (prompt, words_) in enumerate(zip(prompts, substruct_words)):
81
+ if type(words_) is str:
82
+ words_ = [words_]
83
+ for word in words_:
84
+ ind = ptp_utils.get_word_inds(prompt, word, tokenizer)
85
+ substruct_layers[i, :, :, :, :, ind] = 1
86
+ self.substruct_layers = substruct_layers.to(device)
87
+ else:
88
+ self.substruct_layers = None
89
+ self.alpha_layers = alpha_layers.to(device)
90
+ self.start_blend = int(start_blend * NUM_DDIM_STEPS)
91
+ self.counter = 0
92
+ self.th=th
93
+
94
+
95
+ class EmptyControl:
96
+
97
+
98
+ def step_callback(self, x_t):
99
+ return x_t
100
+
101
+ def between_steps(self):
102
+ return
103
+
104
+ def __call__(self, attn, is_cross: bool, place_in_unet: str):
105
+ return attn
106
+
107
+
108
+ class AttentionControl(abc.ABC):
109
+
110
+ def step_callback(self, x_t):
111
+ return x_t
112
+
113
+ def between_steps(self):
114
+ return
115
+
116
+ @property
117
+ def num_uncond_att_layers(self):
118
+ return self.num_att_layers if LOW_RESOURCE else 0
119
+
120
+ @abc.abstractmethod
121
+ def forward (self, attn, is_cross: bool, place_in_unet: str):
122
+ raise NotImplementedError
123
+
124
+ def __call__(self, attn, is_cross: bool, place_in_unet: str):
125
+ if self.cur_att_layer >= self.num_uncond_att_layers:
126
+ if LOW_RESOURCE:
127
+ attn = self.forward(attn, is_cross, place_in_unet)
128
+ else:
129
+ h = attn.shape[0]
130
+ attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
131
+ self.cur_att_layer += 1
132
+ if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
133
+ self.cur_att_layer = 0
134
+ self.cur_step += 1
135
+ self.between_steps()
136
+ return attn
137
+
138
+ def reset(self):
139
+ self.cur_step = 0
140
+ self.cur_att_layer = 0
141
+
142
+ def __init__(self):
143
+ self.cur_step = 0
144
+ self.num_att_layers = -1
145
+ self.cur_att_layer = 0
146
+
147
+
148
+ class SpatialReplace(EmptyControl):
149
+
150
+ def step_callback(self, x_t):
151
+ if self.cur_step < self.stop_inject:
152
+ b = x_t.shape[0]
153
+ x_t = x_t[:1].expand(b, *x_t.shape[1:])
154
+ return x_t
155
+
156
+ def __init__(self, stop_inject: float):
157
+ super(SpatialReplace, self).__init__()
158
+ self.stop_inject = int((1 - stop_inject) * NUM_DDIM_STEPS)
159
+
160
+
161
+ class AttentionStore(AttentionControl):
162
+
163
+ @staticmethod
164
+ def get_empty_store():
165
+ return {"down_cross": [], "mid_cross": [], "up_cross": [],
166
+ "down_self": [], "mid_self": [], "up_self": []}
167
+
168
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
169
+ key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
170
+ if attn.shape[1] <= 32 ** 2: # avoid memory overhead
171
+ self.step_store[key].append(attn)
172
+ return attn
173
+
174
+ def between_steps(self):
175
+ if len(self.attention_store) == 0:
176
+ self.attention_store = self.step_store
177
+ else:
178
+ for key in self.attention_store:
179
+ for i in range(len(self.attention_store[key])):
180
+ self.attention_store[key][i] += self.step_store[key][i]
181
+ self.step_store = self.get_empty_store()
182
+
183
+ def get_average_attention(self):
184
+ average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store}
185
+ return average_attention
186
+
187
+
188
+ def reset(self):
189
+ super(AttentionStore, self).reset()
190
+ self.step_store = self.get_empty_store()
191
+ self.attention_store = {}
192
+
193
+ def __init__(self):
194
+ super(AttentionStore, self).__init__()
195
+ self.step_store = self.get_empty_store()
196
+ self.attention_store = {}
197
+
198
+
199
+ class AttentionControlEdit(AttentionStore, abc.ABC):
200
+
201
+ def step_callback(self, x_t):
202
+ if self.local_blend is not None:
203
+ x_t = self.local_blend(x_t, self.attention_store)
204
+ return x_t
205
+
206
+ def replace_self_attention(self, attn_base, att_replace, place_in_unet):
207
+ if att_replace.shape[2] <= 32 ** 2:
208
+ attn_base = attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
209
+ return attn_base
210
+ else:
211
+ return att_replace
212
+
213
+ @abc.abstractmethod
214
+ def replace_cross_attention(self, attn_base, att_replace):
215
+ raise NotImplementedError
216
+
217
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
218
+ super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
219
+ if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
220
+ h = attn.shape[0] // (self.batch_size)
221
+ attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
222
+ attn_base, attn_repalce = attn[0], attn[1:]
223
+ if is_cross:
224
+ alpha_words = self.cross_replace_alpha[self.cur_step]
225
+ attn_repalce_new = self.replace_cross_attention(attn_base, attn_repalce) * alpha_words + (1 - alpha_words) * attn_repalce
226
+ attn[1:] = attn_repalce_new
227
+ else:
228
+ attn[1:] = self.replace_self_attention(attn_base, attn_repalce, place_in_unet)
229
+ attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
230
+ return attn
231
+
232
+ def __init__(self, prompts, num_steps: int,
233
+ cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
234
+ self_replace_steps: Union[float, Tuple[float, float]],
235
+ local_blend: Optional[LocalBlend]):
236
+ super(AttentionControlEdit, self).__init__()
237
+ self.batch_size = len(prompts)
238
+ self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps, tokenizer).to(device)
239
+ if type(self_replace_steps) is float:
240
+ self_replace_steps = 0, self_replace_steps
241
+ self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
242
+ self.local_blend = local_blend
243
+
244
+ class AttentionReplace(AttentionControlEdit):
245
+
246
+ def replace_cross_attention(self, attn_base, att_replace):
247
+ return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper)
248
+
249
+ def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
250
+ local_blend: Optional[LocalBlend] = None):
251
+ super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
252
+ self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device)
253
+
254
+
255
+ class AttentionRefine(AttentionControlEdit):
256
+
257
+ def replace_cross_attention(self, attn_base, att_replace):
258
+ attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3)
259
+ attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas)
260
+ # attn_replace = attn_replace / attn_replace.sum(-1, keepdims=True)
261
+ return attn_replace
262
+
263
+ def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
264
+ local_blend: Optional[LocalBlend] = None):
265
+ super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
266
+ self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer)
267
+ self.mapper, alphas = self.mapper.to(device), alphas.to(device)
268
+ self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])
269
+
270
+
271
+ class AttentionReweight(AttentionControlEdit):
272
+
273
+ def replace_cross_attention(self, attn_base, att_replace):
274
+ if self.prev_controller is not None:
275
+ attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace)
276
+ attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :]
277
+ # attn_replace = attn_replace / attn_replace.sum(-1, keepdims=True)
278
+ return attn_replace
279
+
280
+ def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, equalizer,
281
+ local_blend: Optional[LocalBlend] = None, controller: Optional[AttentionControlEdit] = None):
282
+ super(AttentionReweight, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
283
+ self.equalizer = equalizer.to(device)
284
+ self.prev_controller = controller
285
+
286
+
287
+ def get_equalizer(text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float],
288
+ Tuple[float, ...]]):
289
+ if type(word_select) is int or type(word_select) is str:
290
+ word_select = (word_select,)
291
+ equalizer = torch.ones(1, 77)
292
+
293
+ for word, val in zip(word_select, values):
294
+ inds = ptp_utils.get_word_inds(text, word, tokenizer)
295
+ equalizer[:, inds] = val
296
+ return equalizer
297
+
298
+ def aggregate_attention(attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int, prompts: List[str] = None):
299
+ out = []
300
+ attention_maps = attention_store.get_average_attention()
301
+ num_pixels = res ** 2
302
+ for location in from_where:
303
+ for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
304
+ if item.shape[1] == num_pixels:
305
+ cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
306
+ out.append(cross_maps)
307
+ out = torch.cat(out, dim=0)
308
+ out = out.sum(0) / out.shape[0]
309
+ return out.cpu()
310
+
311
+
312
+ def make_controller(prompts: List[str], is_replace_controller: bool, cross_replace_steps: Dict[str, float], self_replace_steps: float, blend_words=None, equilizer_params=None) -> AttentionControlEdit:
313
+ if blend_words is None:
314
+ lb = None
315
+ else:
316
+ lb = LocalBlend(prompts, blend_words)
317
+ if is_replace_controller:
318
+ controller = AttentionReplace(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps, self_replace_steps=self_replace_steps, local_blend=lb)
319
+ else:
320
+ controller = AttentionRefine(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps, self_replace_steps=self_replace_steps, local_blend=lb)
321
+ if equilizer_params is not None:
322
+ eq = get_equalizer(prompts[1], equilizer_params["words"], equilizer_params["values"])
323
+ controller = AttentionReweight(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps,
324
+ self_replace_steps=self_replace_steps, equalizer=eq, local_blend=lb, controller=controller)
325
+ return controller
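+ # Illustrative call (placeholder prompts/parameters, not part of this commit):
+ # >>> controller = make_controller(["a man is surfing", "a panda is surfing"],
+ # ...                              is_replace_controller=True,
+ # ...                              cross_replace_steps={"default_": 0.8}, self_replace_steps=0.4)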
326
+
327
+
328
+ def show_cross_attention(attention_store: AttentionStore, res: int, from_where: List[str], select: int = 0, prompts: List[str] = None):
329
+ tokens = tokenizer.encode(prompts[select])
330
+ decoder = tokenizer.decode
331
+ attention_maps = aggregate_attention(attention_store, res, from_where, True, select, prompts=prompts)
332
+ images = []
333
+ for i in range(len(tokens)):
334
+ image = attention_maps[:, :, i]
335
+ image = 255 * image / image.max()
336
+ image = image.unsqueeze(-1).expand(*image.shape, 3)
337
+ image = image.numpy().astype(np.uint8)
338
+ image = np.array(Image.fromarray(image).resize((256, 256)))
339
+ image = ptp_utils.text_under_image(image, decoder(int(tokens[i])))
340
+ images.append(image)
341
+ ptp_utils.view_images(np.stack(images, axis=0))
342
+
343
+
344
+ class NullInversion:
345
+
346
+ def prev_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]):
347
+ prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
348
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
349
+ alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
350
+ beta_prod_t = 1 - alpha_prod_t
351
+ pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
352
+ pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output
353
+ prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
354
+ return prev_sample
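+ # prev_step / next_step implement the deterministic DDIM update (eta = 0):
+ #   x_{t'} = sqrt(alpha_bar_{t'}) * x0_pred + sqrt(1 - alpha_bar_{t'}) * eps,
+ # with x0_pred = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t); next_step runs the same update
+ # in the reverse direction to perform DDIM inversion.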
355
+
356
+ def next_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]):
357
+ timestep, next_timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999), timestep
358
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
359
+ alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]
360
+ beta_prod_t = 1 - alpha_prod_t
361
+ next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
362
+ next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
363
+ next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
364
+ return next_sample
365
+
366
+ def get_noise_pred_single(self, latents, t, context, normal_infer=True):
367
+ noise_pred = self.model.unet(latents, t, encoder_hidden_states=context, normal_infer=normal_infer)["sample"]
368
+ return noise_pred
369
+
370
+ def get_noise_pred(self, latents, t, is_forward=True, context=None, normal_infer=True):
371
+ latents_input = torch.cat([latents] * 2)
372
+ if context is None:
373
+ context = self.context
374
+ guidance_scale = 1 if is_forward else self.guidance_scale
375
+ noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=context, normal_infer=normal_infer)["sample"]
376
+ noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
377
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
378
+ if is_forward:
379
+ latents = self.next_step(noise_pred, t, latents)
380
+ else:
381
+ latents = self.prev_step(noise_pred, t, latents)
382
+ return latents
383
+
384
+ @torch.no_grad()
385
+ def latent2image(self, latents, return_type='np'):
386
+ latents = 1 / 0.18215 * latents.detach()
387
+ image = self.model.vae.decode(latents)['sample']
388
+ if return_type == 'np':
389
+ image = (image / 2 + 0.5).clamp(0, 1)
390
+ image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
391
+ image = (image * 255).astype(np.uint8)
392
+ return image
393
+
394
+ @torch.no_grad()
395
+ def image2latent(self, image):
396
+ with torch.no_grad():
397
+ if isinstance(image, Image.Image):
398
+ image = np.array(image)
399
+ if type(image) is torch.Tensor and image.dim() == 4:
400
+ latents = image
401
+ else:
402
+ image = torch.from_numpy(image).float() / 127.5 - 1
403
+ image = image.permute(2, 0, 1).unsqueeze(0).to(device)
404
+ latents = self.model.vae.encode(image)['latent_dist'].mean
405
+ latents = latents * 0.18215
406
+ return latents
407
+
408
+ @torch.no_grad()
409
+ def init_prompt(self, prompt: str):
410
+ uncond_input = self.model.tokenizer(
411
+ [""], padding="max_length", max_length=self.model.tokenizer.model_max_length,
412
+ return_tensors="pt"
413
+ )
414
+ uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0]
415
+ text_input = self.model.tokenizer(
416
+ [prompt],
417
+ padding="max_length",
418
+ max_length=self.model.tokenizer.model_max_length,
419
+ truncation=True,
420
+ return_tensors="pt",
421
+ )
422
+ # (1, 77, 768)
423
+ text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.model.device))[0]
424
+ # (2, 77, 768)
425
+ self.context = torch.cat([uncond_embeddings, text_embeddings])
426
+ self.prompt = prompt
427
+
428
+ @torch.no_grad()
429
+ def ddim_loop(self, latent):
430
+ uncond_embeddings, cond_embeddings = self.context.chunk(2)
431
+ cond = cond_embeddings if self.null_inv_with_prompt else uncond_embeddings
432
+ all_latent = [latent]
433
+ latent = latent.clone().detach()
434
+ for i in range(NUM_DDIM_STEPS):
435
+ t = self.model.scheduler.timesteps[len(self.model.scheduler.timesteps) - i - 1]
436
+ noise_pred = self.get_noise_pred_single(latent, t, cond, normal_infer=True)
437
+ latent = self.next_step(noise_pred, t, latent)
438
+ all_latent.append(latent)
439
+ return all_latent
440
+
441
+ @property
442
+ def scheduler(self):
443
+ return self.model.scheduler
444
+
445
+ @torch.no_grad()
446
+ def ddim_inversion(self, latent):
447
+ ddim_latents = self.ddim_loop(latent)
448
+ return ddim_latents
449
+
450
+ def null_optimization(self, latents, null_inner_steps, epsilon, null_base_lr=1e-2):
451
+ uncond_embeddings, cond_embeddings = self.context.chunk(2)
452
+ uncond_embeddings_list = []
453
+ latent_cur = latents[-1]
454
+ bar = tqdm(total=null_inner_steps * NUM_DDIM_STEPS)
455
+ for i in range(NUM_DDIM_STEPS):
456
+ uncond_embeddings = uncond_embeddings.clone().detach()
457
+ uncond_embeddings.requires_grad = True
458
+ optimizer = Adam([uncond_embeddings], lr=null_base_lr * (1. - i / 100.))
459
+ latent_prev = latents[len(latents) - i - 2]
460
+ t = self.model.scheduler.timesteps[i]
461
+ with torch.no_grad():
462
+ noise_pred_cond = self.get_noise_pred_single(latent_cur, t, cond_embeddings, normal_infer=self.null_normal_infer)
463
+ for j in range(null_inner_steps):
464
+ noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings, normal_infer=self.null_normal_infer)
465
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
466
+ latents_prev_rec = self.prev_step(noise_pred, t, latent_cur)
467
+ loss = nnf.mse_loss(latents_prev_rec, latent_prev)
468
+ optimizer.zero_grad()
469
+ loss.backward()
470
+ optimizer.step()
471
+ assert not torch.isnan(uncond_embeddings.abs().mean())
472
+ loss_item = loss.item()
473
+ bar.update()
474
+ if loss_item < epsilon + i * 2e-5:
475
+ break
476
+ for j in range(j + 1, null_inner_steps):
477
+ bar.update()
478
+ uncond_embeddings_list.append(uncond_embeddings[:1].detach())
479
+ with torch.no_grad():
480
+ context = torch.cat([uncond_embeddings, cond_embeddings])
481
+ latent_cur = self.get_noise_pred(latent_cur, t, False, context, normal_infer=self.null_normal_infer)
482
+ bar.close()
483
+ return uncond_embeddings_list
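+ # Null-text optimization: at each DDIM step the unconditional ("null") embedding is tuned with Adam so that
+ # classifier-free-guided denoising reproduces the stored inversion latent, yielding one embedding per step.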
484
+
485
+ def invert(self, latents: torch.Tensor, prompt: str, null_inner_steps=10, early_stop_epsilon=1e-5, verbose=False, null_base_lr=1e-2):
486
+ self.init_prompt(prompt)
487
+ if verbose:
488
+ print("DDIM inversion...")
489
+ ddim_latents = self.ddim_inversion(latents.to(torch.float32))
490
+ if verbose:
491
+ print("Null-text optimization...")
492
+ uncond_embeddings = self.null_optimization(ddim_latents, null_inner_steps, early_stop_epsilon, null_base_lr=null_base_lr)
493
+ return ddim_latents[-1], uncond_embeddings
494
+
495
+
496
+ def __init__(self, model, guidance_scale, null_inv_with_prompt, null_normal_infer=True):
497
+ self.null_normal_infer = null_normal_infer
498
+ self.null_inv_with_prompt = null_inv_with_prompt
499
+ self.guidance_scale = guidance_scale
500
+ self.model = model
501
+ self.tokenizer = self.model.tokenizer
502
+ self.model.scheduler.set_timesteps(NUM_DDIM_STEPS)
503
+ self.prompt = None
504
+ self.context = None
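+ # Rough usage sketch (assumes a pipeline-like object exposing .unet/.vae/.text_encoder/.tokenizer/.scheduler;
+ # names are placeholders, not part of this commit):
+ # >>> inverter = NullInversion(pipe, guidance_scale=7.5, null_inv_with_prompt=False)
+ # >>> x_T, uncond_embeddings = inverter.invert(video_latents, "a man is surfing", verbose=True)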
vid2vid_zero/p2p/p2p_stable.py ADDED
@@ -0,0 +1,242 @@
1
+ # Copyright 2022 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional, Union, Tuple, List, Callable, Dict
16
+ import torch
17
+ import torch.nn.functional as nnf
18
+ import numpy as np
19
+ import abc
20
+ from . import ptp_utils
21
+ from . import seq_aligner
22
+ from transformers import CLIPTextModel, CLIPTokenizer
23
+
24
+ pretrained_model_path = "checkpoints/stable-diffusion-v1-4/"
25
+ ldm_stable = None
26
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
27
+
28
+ LOW_RESOURCE = False
29
+ NUM_DIFFUSION_STEPS = 50
30
+ GUIDANCE_SCALE = 7.5
31
+ MAX_NUM_WORDS = 77
32
+ device = torch.device('cuda')
33
+
34
+
35
+ class LocalBlend:
36
+
37
+ def __call__(self, x_t, attention_store):
38
+ k = 1
39
+ maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3]
40
+ maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, MAX_NUM_WORDS) for item in maps]
41
+ maps = torch.cat(maps, dim=1)
42
+ maps = (maps * self.alpha_layers).sum(-1).mean(1)
43
+ mask = nnf.max_pool2d(maps, (k * 2 + 1, k * 2 +1), (1, 1), padding=(k, k))
44
+ mask = nnf.interpolate(mask, size=(x_t.shape[2:]))
45
+ mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
46
+ mask = mask.gt(self.threshold)
47
+ mask = (mask[:1] + mask[1:]).float()
48
+ x_t = x_t[:1] + mask * (x_t - x_t[:1])
49
+ return x_t
50
+
51
+ # def __init__(self, prompts: List[str], words: [List[List[str]]], threshold=.3):
52
+ def __init__(self, prompts: List[str], words: List[List[str]], threshold=.3):
53
+ alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS)
54
+ for i, (prompt, words_) in enumerate(zip(prompts, words)):
55
+ if type(words_) is str:
56
+ words_ = [words_]
57
+ for word in words_:
58
+ ind = ptp_utils.get_word_inds(prompt, word, tokenizer)
59
+ alpha_layers[i, :, :, :, :, ind] = 1
60
+ self.alpha_layers = alpha_layers.to(device)
61
+ self.threshold = threshold
62
+
63
+
64
+ class AttentionControl(abc.ABC):
65
+
66
+ def step_callback(self, x_t):
67
+ return x_t
68
+
69
+ def between_steps(self):
70
+ return
71
+
72
+ @property
73
+ def num_uncond_att_layers(self):
74
+ return self.num_att_layers if LOW_RESOURCE else 0
75
+
76
+ @abc.abstractmethod
77
+ def forward (self, attn, is_cross: bool, place_in_unet: str):
78
+ raise NotImplementedError
79
+
80
+ def __call__(self, attn, is_cross: bool, place_in_unet: str):
81
+ if self.cur_att_layer >= self.num_uncond_att_layers:
82
+ if LOW_RESOURCE:
83
+ attn = self.forward(attn, is_cross, place_in_unet)
84
+ else:
85
+ h = attn.shape[0]
86
+ attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
87
+ self.cur_att_layer += 1
88
+ if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
89
+ self.cur_att_layer = 0
90
+ self.cur_step += 1
91
+ self.between_steps()
92
+ return attn
93
+
94
+ def reset(self):
95
+ self.cur_step = 0
96
+ self.cur_att_layer = 0
97
+
98
+ def __init__(self):
99
+ self.cur_step = 0
100
+ self.num_att_layers = -1
101
+ self.cur_att_layer = 0
102
+
103
+ class EmptyControl(AttentionControl):
104
+
105
+ def forward (self, attn, is_cross: bool, place_in_unet: str):
106
+ return attn
107
+
108
+
109
+ class AttentionStore(AttentionControl):
110
+
111
+ @staticmethod
112
+ def get_empty_store():
113
+ return {"down_cross": [], "mid_cross": [], "up_cross": [],
114
+ "down_self": [], "mid_self": [], "up_self": []}
115
+
116
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
117
+ key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
118
+ if attn.shape[1] <= 32 ** 2: # avoid memory overhead
119
+ self.step_store[key].append(attn)
120
+ return attn
121
+
122
+ def between_steps(self):
123
+ if len(self.attention_store) == 0:
124
+ self.attention_store = self.step_store
125
+ else:
126
+ for key in self.attention_store:
127
+ for i in range(len(self.attention_store[key])):
128
+ self.attention_store[key][i] += self.step_store[key][i]
129
+ self.step_store = self.get_empty_store()
130
+
131
+ def get_average_attention(self):
132
+ average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store}
133
+ return average_attention
134
+
135
+
136
+ def reset(self):
137
+ super(AttentionStore, self).reset()
138
+ self.step_store = self.get_empty_store()
139
+ self.attention_store = {}
140
+
141
+ def __init__(self):
142
+ super(AttentionStore, self).__init__()
143
+ self.step_store = self.get_empty_store()
144
+ self.attention_store = {}
145
+
146
+
147
+ class AttentionControlEdit(AttentionStore, abc.ABC):
148
+
149
+ def step_callback(self, x_t):
150
+ if self.local_blend is not None:
151
+ x_t = self.local_blend(x_t, self.attention_store)
152
+ return x_t
153
+
154
+ def replace_self_attention(self, attn_base, att_replace):
155
+ if att_replace.shape[2] <= 16 ** 2:
156
+ return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
157
+ else:
158
+ return att_replace
159
+
160
+ @abc.abstractmethod
161
+ def replace_cross_attention(self, attn_base, att_replace):
162
+ raise NotImplementedError
163
+
164
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
165
+ super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
166
+ if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
167
+ h = attn.shape[0] // (self.batch_size)
168
+ attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
169
+ attn_base, attn_repalce = attn[0], attn[1:]
170
+ if is_cross:
171
+ alpha_words = self.cross_replace_alpha[self.cur_step]
172
+ attn_repalce_new = self.replace_cross_attention(attn_base, attn_repalce) * alpha_words + (1 - alpha_words) * attn_repalce
173
+ attn[1:] = attn_repalce_new
174
+ else:
175
+ attn[1:] = self.replace_self_attention(attn_base, attn_repalce)
176
+ attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
177
+ return attn
178
+
179
+ def __init__(self, prompts, num_steps: int,
180
+ cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
181
+ self_replace_steps: Union[float, Tuple[float, float]],
182
+ local_blend: Optional[LocalBlend]):
183
+ super(AttentionControlEdit, self).__init__()
184
+ self.batch_size = len(prompts)
185
+ self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps, tokenizer).to(device)
186
+ if type(self_replace_steps) is float:
187
+ self_replace_steps = 0, self_replace_steps
188
+ self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
189
+ self.local_blend = local_blend
190
+
191
+
192
+ class AttentionReplace(AttentionControlEdit):
193
+
194
+ def replace_cross_attention(self, attn_base, att_replace):
195
+ return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper)
196
+
197
+ def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
198
+ local_blend: Optional[LocalBlend] = None):
199
+ super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
200
+ self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device)
201
+
202
+
203
+ class AttentionRefine(AttentionControlEdit):
204
+
205
+ def replace_cross_attention(self, attn_base, att_replace):
206
+ attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3)
207
+ attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas)
208
+ return attn_replace
209
+
210
+ def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
211
+ local_blend: Optional[LocalBlend] = None):
212
+ super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
213
+ self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer)
214
+ self.mapper, alphas = self.mapper.to(device), alphas.to(device)
215
+ self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])
216
+
217
+
218
+ class AttentionReweight(AttentionControlEdit):
219
+
220
+ def replace_cross_attention(self, attn_base, att_replace):
221
+ if self.prev_controller is not None:
222
+ attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace)
223
+ attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :]
224
+ return attn_replace
225
+
226
+ def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, equalizer,
227
+ local_blend: Optional[LocalBlend] = None, controller: Optional[AttentionControlEdit] = None):
228
+ super(AttentionReweight, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
229
+ self.equalizer = equalizer.to(device)
230
+ self.prev_controller = controller
231
+
232
+
233
+ def get_equalizer(text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float],
234
+ Tuple[float, ...]]):
235
+ if type(word_select) is int or type(word_select) is str:
236
+ word_select = (word_select,)
237
+ equalizer = torch.ones(len(values), 77)
238
+ values = torch.tensor(values, dtype=torch.float32)
239
+ for word in word_select:
240
+ inds = ptp_utils.get_word_inds(text, word, tokenizer)
241
+ equalizer[:, inds] = values
242
+ return equalizer
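+ # get_equalizer builds a (len(values), 77) scaling table over token positions; e.g. (hypothetical values)
+ # get_equalizer("a silver jeep", word_select=("silver",), values=(2.0,)) amplifies cross-attention on "silver"
+ # when the table is consumed by AttentionReweight above.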
vid2vid_zero/p2p/ptp_utils.py ADDED
@@ -0,0 +1,347 @@
1
+ # Copyright 2022 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import numpy as np
16
+ import torch
17
+ from PIL import Image, ImageDraw, ImageFont
18
+ import cv2
19
+ from typing import Optional, Union, Tuple, List, Callable, Dict
20
+ from IPython.display import display
21
+ from tqdm import tqdm
22
+ import torch.nn.functional as F
23
+
24
+
25
+ def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
26
+ h, w, c = image.shape
27
+ offset = int(h * .2)
28
+ img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
29
+ font = cv2.FONT_HERSHEY_SIMPLEX
30
+ # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size)
31
+ img[:h] = image
32
+ textsize = cv2.getTextSize(text, font, 1, 2)[0]
33
+ text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
34
+ cv2.putText(img, text, (text_x, text_y ), font, 1, text_color, 2)
35
+ return img
36
+
37
+
38
+ def view_images(images, num_rows=1, offset_ratio=0.02):
39
+ if type(images) is list:
40
+ num_empty = len(images) % num_rows
41
+ elif images.ndim == 4:
42
+ num_empty = images.shape[0] % num_rows
43
+ else:
44
+ images = [images]
45
+ num_empty = 0
46
+
47
+ empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
48
+ images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
49
+ num_items = len(images)
50
+
51
+ h, w, c = images[0].shape
52
+ offset = int(h * offset_ratio)
53
+ num_cols = num_items // num_rows
54
+ image_ = np.ones((h * num_rows + offset * (num_rows - 1),
55
+ w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
56
+ for i in range(num_rows):
57
+ for j in range(num_cols):
58
+ image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
59
+ i * num_cols + j]
60
+
61
+ pil_img = Image.fromarray(image_)
62
+ display(pil_img)
63
+
64
+
65
+ def diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False):
66
+ if low_resource:
67
+ noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"]
68
+ noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"]
69
+ else:
70
+ latents_input = torch.cat([latents] * 2)
71
+ noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
72
+ noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
73
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
74
+ latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"]
75
+ latents = controller.step_callback(latents)
76
+ return latents
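+ # Classifier-free guidance: eps = eps_uncond + guidance_scale * (eps_text - eps_uncond), followed by one
+ # scheduler step and the controller's step_callback (used e.g. by LocalBlend to restrict where edits apply).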
77
+
78
+
79
+ def latent2image(vae, latents):
80
+ latents = 1 / 0.18215 * latents
81
+ image = vae.decode(latents)['sample']
82
+ image = (image / 2 + 0.5).clamp(0, 1)
83
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
84
+ image = (image * 255).astype(np.uint8)
85
+ return image
86
+
87
+
88
+ def init_latent(latent, model, height, width, generator, batch_size):
89
+ if latent is None:
90
+ latent = torch.randn(
91
+ (1, model.unet.in_channels, height // 8, width // 8),
92
+ generator=generator,
93
+ )
94
+ latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device)
95
+ return latent, latents
96
+
97
+
98
+ @torch.no_grad()
99
+ def text2image_ldm(
100
+ model,
101
+ prompt: List[str],
102
+ controller,
103
+ num_inference_steps: int = 50,
104
+ guidance_scale: Optional[float] = 7.,
105
+ generator: Optional[torch.Generator] = None,
106
+ latent: Optional[torch.FloatTensor] = None,
107
+ ):
108
+ register_attention_control(model, controller)
109
+ height = width = 256
110
+ batch_size = len(prompt)
111
+
112
+ uncond_input = model.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
113
+ uncond_embeddings = model.bert(uncond_input.input_ids.to(model.device))[0]
114
+
115
+ text_input = model.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
116
+ text_embeddings = model.bert(text_input.input_ids.to(model.device))[0]
117
+ latent, latents = init_latent(latent, model, height, width, generator, batch_size)
118
+ context = torch.cat([uncond_embeddings, text_embeddings])
119
+
120
+ model.scheduler.set_timesteps(num_inference_steps)
121
+ for t in tqdm(model.scheduler.timesteps):
122
+ latents = diffusion_step(model, controller, latents, context, t, guidance_scale)
123
+
124
+ image = latent2image(model.vqvae, latents)
125
+
126
+ return image, latent
127
+
128
+
129
+ @torch.no_grad()
130
+ def text2image_ldm_stable(
131
+ model,
132
+ prompt: List[str],
133
+ controller,
134
+ num_inference_steps: int = 50,
135
+ guidance_scale: float = 7.5,
136
+ generator: Optional[torch.Generator] = None,
137
+ latent: Optional[torch.FloatTensor] = None,
138
+ low_resource: bool = False,
139
+ ):
140
+ register_attention_control(model, controller)
141
+ height = width = 512
142
+ batch_size = len(prompt)
143
+
144
+ text_input = model.tokenizer(
145
+ prompt,
146
+ padding="max_length",
147
+ max_length=model.tokenizer.model_max_length,
148
+ truncation=True,
149
+ return_tensors="pt",
150
+ )
151
+ text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
152
+ max_length = text_input.input_ids.shape[-1]
153
+ uncond_input = model.tokenizer(
154
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
155
+ )
156
+ uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0]
157
+
158
+ context = [uncond_embeddings, text_embeddings]
159
+ if not low_resource:
160
+ context = torch.cat(context)
161
+ latent, latents = init_latent(latent, model, height, width, generator, batch_size)
162
+
163
+ # set timesteps
164
+ extra_set_kwargs = {"offset": 1}
165
+ model.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
166
+ for t in tqdm(model.scheduler.timesteps):
167
+ latents = diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource)
168
+
169
+ image = latent2image(model.vae, latents)
170
+
171
+ return image, latent
172
+
173
+
174
+ def register_attention_control(model, controller):
175
+
176
+ def ca_forward(self, place_in_unet):
177
+ def forward(hidden_states, encoder_hidden_states=None, attention_mask=None):
178
+ batch_size, sequence_length, _ = hidden_states.shape
179
+
180
+ is_cross = encoder_hidden_states is not None
181
+ encoder_hidden_states = encoder_hidden_states
182
+
183
+ if self.group_norm is not None:
184
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
185
+
186
+ query = self.to_q(hidden_states)
187
+ # dim = query.shape[-1]
188
+ query = self.reshape_heads_to_batch_dim(query)
189
+
190
+ if self.added_kv_proj_dim is not None:
191
+ key = self.to_k(hidden_states)
192
+ value = self.to_v(hidden_states)
193
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
194
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
195
+
196
+ key = self.reshape_heads_to_batch_dim(key)
197
+ value = self.reshape_heads_to_batch_dim(value)
198
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
199
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
200
+
201
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
202
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
203
+ else:
204
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
205
+ key = self.to_k(encoder_hidden_states)
206
+ value = self.to_v(encoder_hidden_states)
207
+
208
+ key = self.reshape_heads_to_batch_dim(key)
209
+ value = self.reshape_heads_to_batch_dim(value)
210
+
211
+ if attention_mask is not None:
212
+ if attention_mask.shape[-1] != query.shape[1]:
213
+ target_length = query.shape[1]
214
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
215
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
216
+
217
+ assert self._slice_size is None or query.shape[0] // self._slice_size == 1
218
+
219
+ if self.upcast_attention:
220
+ query = query.float()
221
+ key = key.float()
222
+
223
+ attention_scores = torch.baddbmm(
224
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
225
+ query,
226
+ key.transpose(-1, -2),
227
+ beta=0,
228
+ alpha=self.scale,
229
+ )
230
+
231
+ if attention_mask is not None:
232
+ attention_scores = attention_scores + attention_mask
233
+
234
+ if self.upcast_softmax:
235
+ attention_scores = attention_scores.float()
236
+
237
+ attention_probs = attention_scores.softmax(dim=-1)
238
+
239
+ # attn control
240
+ attention_probs = controller(attention_probs, is_cross, place_in_unet)
241
+
242
+ # cast back to the original dtype
243
+ attention_probs = attention_probs.to(value.dtype)
244
+
245
+ # compute attention output
246
+ hidden_states = torch.bmm(attention_probs, value)
247
+
248
+ # reshape hidden_states
249
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
250
+
251
+ # linear proj
252
+ hidden_states = self.to_out[0](hidden_states)
253
+
254
+ # dropout
255
+ hidden_states = self.to_out[1](hidden_states)
256
+ return hidden_states
257
+
258
+ return forward
259
+
260
+ class DummyController:
261
+
262
+ def __call__(self, *args):
263
+ return args[0]
264
+
265
+ def __init__(self):
266
+ self.num_att_layers = 0
267
+
268
+ if controller is None:
269
+ controller = DummyController()
270
+
271
+ def register_recr(net_, count, place_in_unet):
272
+ if net_.__class__.__name__ == 'CrossAttention':
273
+ net_.forward = ca_forward(net_, place_in_unet)
274
+ return count + 1
275
+ elif hasattr(net_, 'children'):
276
+ for net__ in net_.children():
277
+ count = register_recr(net__, count, place_in_unet)
278
+ return count
279
+
280
+ cross_att_count = 0
281
+ # sub_nets = model.unet.named_children()
282
+ # we take unet as the input model
283
+ sub_nets = model.named_children()
284
+ for net in sub_nets:
285
+ if "down" in net[0]:
286
+ cross_att_count += register_recr(net[1], 0, "down")
287
+ elif "up" in net[0]:
288
+ cross_att_count += register_recr(net[1], 0, "up")
289
+ elif "mid" in net[0]:
290
+ cross_att_count += register_recr(net[1], 0, "mid")
291
+
292
+ controller.num_att_layers = cross_att_count
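+ # register_attention_control monkey-patches the forward of every CrossAttention module inside the given UNet,
+ # so each attention-probability map is routed through `controller` before being applied to the values.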
293
+
294
+
295
+ def get_word_inds(text: str, word_place: int, tokenizer):
296
+ split_text = text.split(" ")
297
+ if type(word_place) is str:
298
+ word_place = [i for i, word in enumerate(split_text) if word_place == word]
299
+ elif type(word_place) is int:
300
+ word_place = [word_place]
301
+ out = []
302
+ if len(word_place) > 0:
303
+ words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
304
+ cur_len, ptr = 0, 0
305
+
306
+ for i in range(len(words_encode)):
307
+ cur_len += len(words_encode[i])
308
+ if ptr in word_place:
309
+ out.append(i + 1)
310
+ if cur_len >= len(split_text[ptr]):
311
+ ptr += 1
312
+ cur_len = 0
313
+ return np.array(out)
314
+
315
+
316
+ def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int,
317
+ word_inds: Optional[torch.Tensor]=None):
318
+ if type(bounds) is float:
319
+ bounds = 0, bounds
320
+ start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
321
+ if word_inds is None:
322
+ word_inds = torch.arange(alpha.shape[2])
323
+ alpha[: start, prompt_ind, word_inds] = 0
324
+ alpha[start: end, prompt_ind, word_inds] = 1
325
+ alpha[end:, prompt_ind, word_inds] = 0
326
+ return alpha
327
+
328
+
329
+ def get_time_words_attention_alpha(prompts, num_steps,
330
+ cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
331
+ tokenizer, max_num_words=77):
332
+ if type(cross_replace_steps) is not dict:
333
+ cross_replace_steps = {"default_": cross_replace_steps}
334
+ if "default_" not in cross_replace_steps:
335
+ cross_replace_steps["default_"] = (0., 1.)
336
+ alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
337
+ for i in range(len(prompts) - 1):
338
+ alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"],
339
+ i)
340
+ for key, item in cross_replace_steps.items():
341
+ if key != "default_":
342
+ inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
343
+ for i, ind in enumerate(inds):
344
+ if len(ind) > 0:
345
+ alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
346
+ alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)
347
+ return alpha_time_words
vid2vid_zero/p2p/seq_aligner.py ADDED
@@ -0,0 +1,197 @@
1
+ # Copyright 2022 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import numpy as np
17
+
18
+
19
+ class ScoreParams:
20
+
21
+ def __init__(self, gap, match, mismatch):
22
+ self.gap = gap
23
+ self.match = match
24
+ self.mismatch = mismatch
25
+
26
+ def mis_match_char(self, x, y):
27
+ if x != y:
28
+ return self.mismatch
29
+ else:
30
+ return self.match
31
+
32
+
33
+ def get_matrix(size_x, size_y, gap):
34
+ matrix = []
35
+ for i in range(len(size_x) + 1):
36
+ sub_matrix = []
37
+ for j in range(len(size_y) + 1):
38
+ sub_matrix.append(0)
39
+ matrix.append(sub_matrix)
40
+ for j in range(1, len(size_y) + 1):
41
+ matrix[0][j] = j*gap
42
+ for i in range(1, len(size_x) + 1):
43
+ matrix[i][0] = i*gap
44
+ return matrix
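+ # NOTE: this list-based get_matrix is shadowed by the vectorized numpy version defined directly below
+ # (both appear to be carried over from the original prompt-to-prompt code).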
45
+
46
+
47
+ def get_matrix(size_x, size_y, gap):
48
+ matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
49
+ matrix[0, 1:] = (np.arange(size_y) + 1) * gap
50
+ matrix[1:, 0] = (np.arange(size_x) + 1) * gap
51
+ return matrix
52
+
53
+
54
+ def get_traceback_matrix(size_x, size_y):
55
+ matrix = np.zeros((size_x + 1, size_y +1), dtype=np.int32)
56
+ matrix[0, 1:] = 1
57
+ matrix[1:, 0] = 2
58
+ matrix[0, 0] = 4
59
+ return matrix
60
+
61
+
62
+ def global_align(x, y, score):
63
+ matrix = get_matrix(len(x), len(y), score.gap)
64
+ trace_back = get_traceback_matrix(len(x), len(y))
65
+ for i in range(1, len(x) + 1):
66
+ for j in range(1, len(y) + 1):
67
+ left = matrix[i, j - 1] + score.gap
68
+ up = matrix[i - 1, j] + score.gap
69
+ diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1])
70
+ matrix[i, j] = max(left, up, diag)
71
+ if matrix[i, j] == left:
72
+ trace_back[i, j] = 1
73
+ elif matrix[i, j] == up:
74
+ trace_back[i, j] = 2
75
+ else:
76
+ trace_back[i, j] = 3
77
+ return matrix, trace_back
78
+
79
+
80
+ def get_aligned_sequences(x, y, trace_back):
81
+ x_seq = []
82
+ y_seq = []
83
+ i = len(x)
84
+ j = len(y)
85
+ mapper_y_to_x = []
86
+ while i > 0 or j > 0:
87
+ if trace_back[i, j] == 3:
88
+ x_seq.append(x[i-1])
89
+ y_seq.append(y[j-1])
90
+ i = i-1
91
+ j = j-1
92
+ mapper_y_to_x.append((j, i))
93
+ elif trace_back[i][j] == 1:
94
+ x_seq.append('-')
95
+ y_seq.append(y[j-1])
96
+ j = j-1
97
+ mapper_y_to_x.append((j, -1))
98
+ elif trace_back[i][j] == 2:
99
+ x_seq.append(x[i-1])
100
+ y_seq.append('-')
101
+ i = i-1
102
+ elif trace_back[i][j] == 4:
103
+ break
104
+ mapper_y_to_x.reverse()
105
+ return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64)
106
+
107
+
108
+ def get_mapper(x: str, y: str, tokenizer, max_len=77):
109
+ x_seq = tokenizer.encode(x)
110
+ y_seq = tokenizer.encode(y)
111
+ score = ScoreParams(0, 1, -1)
112
+ matrix, trace_back = global_align(x_seq, y_seq, score)
113
+ mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1]
114
+ alphas = torch.ones(max_len)
115
+ alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float()
116
+ mapper = torch.zeros(max_len, dtype=torch.int64)
117
+ mapper[:mapper_base.shape[0]] = mapper_base[:, 1]
118
+ mapper[mapper_base.shape[0]:] = len(y_seq) + torch.arange(max_len - len(y_seq))
119
+ return mapper, alphas
120
+
121
+
122
+ def get_refinement_mapper(prompts, tokenizer, max_len=77):
123
+ x_seq = prompts[0]
124
+ mappers, alphas = [], []
125
+ for i in range(1, len(prompts)):
126
+ mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len)
127
+ mappers.append(mapper)
128
+ alphas.append(alpha)
129
+ return torch.stack(mappers), torch.stack(alphas)
130
+
131
+
132
+ def get_word_inds(text: str, word_place: int, tokenizer):
133
+ split_text = text.split(" ")
134
+ if type(word_place) is str:
135
+ word_place = [i for i, word in enumerate(split_text) if word_place == word]
136
+ elif type(word_place) is int:
137
+ word_place = [word_place]
138
+ out = []
139
+ if len(word_place) > 0:
140
+ words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
141
+ cur_len, ptr = 0, 0
142
+
143
+ for i in range(len(words_encode)):
144
+ cur_len += len(words_encode[i])
145
+ if ptr in word_place:
146
+ out.append(i + 1)
147
+ if cur_len >= len(split_text[ptr]):
148
+ ptr += 1
149
+ cur_len = 0
150
+ return np.array(out)
151
+
152
+
153
+ def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77):
154
+ words_x = x.split(' ')
155
+ words_y = y.split(' ')
156
+ if len(words_x) != len(words_y):
157
+ raise ValueError(f"attention replacement edit can only be applied on prompts with the same length"
158
+ f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.")
159
+ inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]]
160
+ inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace]
161
+ inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace]
162
+ mapper = np.zeros((max_len, max_len))
163
+ i = j = 0
164
+ cur_inds = 0
165
+ while i < max_len and j < max_len:
166
+ if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i:
167
+ inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds]
168
+ if len(inds_source_) == len(inds_target_):
169
+ mapper[inds_source_, inds_target_] = 1
170
+ else:
171
+ ratio = 1 / len(inds_target_)
172
+ for i_t in inds_target_:
173
+ mapper[inds_source_, i_t] = ratio
174
+ cur_inds += 1
175
+ i += len(inds_source_)
176
+ j += len(inds_target_)
177
+ elif cur_inds < len(inds_source):
178
+ mapper[i, j] = 1
179
+ i += 1
180
+ j += 1
181
+ else:
182
+ mapper[j, j] = 1
183
+ i += 1
184
+ j += 1
185
+
186
+ return torch.from_numpy(mapper).float()
187
+
188
+
189
+
190
+ def get_replacement_mapper(prompts, tokenizer, max_len=77):
191
+ x_seq = prompts[0]
192
+ mappers = []
193
+ for i in range(1, len(prompts)):
194
+ mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len)
195
+ mappers.append(mapper)
196
+ return torch.stack(mappers)
197
+
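seq_aligner supplies the token-level correspondences that the attention controllers consume: get_replacement_mapper for word-for-word swaps, and get_refinement_mapper, which runs the Needleman-Wunsch-style global alignment above when the edited prompt adds or removes words. A minimal sketch, not part of this commit, assuming the module is importable as vid2vid_zero.p2p.seq_aligner and using a standard CLIP tokenizer; the prompts are illustrative:

from transformers import CLIPTokenizer
from vid2vid_zero.p2p import seq_aligner

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# Same number of words: one (77, 77) soft permutation matrix per edited prompt.
replace = seq_aligner.get_replacement_mapper(
    ["a cat sitting on the grass", "a dog sitting on the grass"], tokenizer)
print(replace.shape)  # torch.Size([1, 77, 77])

# Different lengths: per edited prompt, a token-index mapper plus alphas that
# are 0 for tokens with no counterpart in the source prompt.
mapper, alphas = seq_aligner.get_refinement_mapper(
    ["a cat sitting on the grass", "a fluffy cat sitting on the grass"], tokenizer)
print(mapper.shape, alphas.shape)  # torch.Size([1, 77]) torch.Size([1, 77])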
vid2vid_zero/pipelines/pipeline_vid2vid_zero.py ADDED
@@ -0,0 +1,541 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Callable, List, Optional, Union
17
+ from dataclasses import dataclass
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ from diffusers.utils import is_accelerate_available
23
+ from packaging import version
24
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
25
+
26
+ from diffusers.configuration_utils import FrozenDict
27
+ from diffusers.models import AutoencoderKL # UNet2DConditionModel
28
+ from diffusers.pipeline_utils import DiffusionPipeline
29
+ from diffusers.schedulers import (
30
+ DDIMScheduler,
31
+ DPMSolverMultistepScheduler,
32
+ EulerAncestralDiscreteScheduler,
33
+ EulerDiscreteScheduler,
34
+ LMSDiscreteScheduler,
35
+ PNDMScheduler,
36
+ )
37
+ from diffusers.utils import deprecate, logging, BaseOutput
38
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
39
+
40
+ from einops import rearrange
41
+
42
+ from ..models.unet_2d_condition import UNet2DConditionModel
43
+
44
+
45
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
46
+
47
+
48
+ @dataclass
49
+ class Vid2VidZeroPipelineOutput(BaseOutput):
50
+ images: Union[torch.Tensor, np.ndarray]
51
+
52
+
53
+ class Vid2VidZeroPipeline(DiffusionPipeline):
54
+ r"""
55
+ Pipeline for text-guided video generation and editing using Stable Diffusion, adapted from the text-to-image pipeline.
56
+
57
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
58
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
59
+
60
+ Args:
61
+ vae ([`AutoencoderKL`]):
62
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
63
+ text_encoder ([`CLIPTextModel`]):
64
+ Frozen text-encoder. Stable Diffusion uses the text portion of
65
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
66
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
67
+ tokenizer (`CLIPTokenizer`):
68
+ Tokenizer of class
69
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
70
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
71
+ scheduler ([`SchedulerMixin`]):
72
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
73
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
74
+ safety_checker ([`StableDiffusionSafetyChecker`]):
75
+ Classification module that estimates whether generated images could be considered offensive or harmful.
76
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
77
+ feature_extractor ([`CLIPFeatureExtractor`]):
78
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
79
+ """
80
+ _optional_components = ["safety_checker", "feature_extractor"]
81
+
82
+ def __init__(
83
+ self,
84
+ vae: AutoencoderKL,
85
+ text_encoder: CLIPTextModel,
86
+ tokenizer: CLIPTokenizer,
87
+ unet: UNet2DConditionModel,
88
+ scheduler: Union[
89
+ DDIMScheduler,
90
+ PNDMScheduler,
91
+ LMSDiscreteScheduler,
92
+ EulerDiscreteScheduler,
93
+ EulerAncestralDiscreteScheduler,
94
+ DPMSolverMultistepScheduler,
95
+ ],
96
+ safety_checker: StableDiffusionSafetyChecker,
97
+ feature_extractor: CLIPFeatureExtractor,
98
+ requires_safety_checker: bool = False,
99
+ ):
100
+ super().__init__()
101
+
102
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
103
+ deprecation_message = (
104
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
105
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
106
+ "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
107
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
108
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
109
+ " file"
110
+ )
111
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
112
+ new_config = dict(scheduler.config)
113
+ new_config["steps_offset"] = 1
114
+ scheduler._internal_dict = FrozenDict(new_config)
115
+
116
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
117
+ deprecation_message = (
118
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
119
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
120
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
121
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
122
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
123
+ )
124
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
125
+ new_config = dict(scheduler.config)
126
+ new_config["clip_sample"] = False
127
+ scheduler._internal_dict = FrozenDict(new_config)
128
+
129
+ if safety_checker is None and requires_safety_checker:
130
+ logger.warning(
131
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
132
+ " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
133
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
134
+ " strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
135
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
136
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
137
+ )
138
+
139
+ if safety_checker is not None and feature_extractor is None:
140
+ raise ValueError(
141
+ f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
142
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
143
+ )
144
+
145
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
146
+ version.parse(unet.config._diffusers_version).base_version
147
+ ) < version.parse("0.9.0.dev0")
148
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
149
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
150
+ deprecation_message = (
151
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
152
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
153
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
154
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
155
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
156
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
157
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
158
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
159
+ " the `unet/config.json` file"
160
+ )
161
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
162
+ new_config = dict(unet.config)
163
+ new_config["sample_size"] = 64
164
+ unet._internal_dict = FrozenDict(new_config)
165
+
166
+ self.register_modules(
167
+ vae=vae,
168
+ text_encoder=text_encoder,
169
+ tokenizer=tokenizer,
170
+ unet=unet,
171
+ scheduler=scheduler,
172
+ safety_checker=safety_checker,
173
+ feature_extractor=feature_extractor,
174
+ )
175
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
176
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
177
+
178
+ def enable_vae_slicing(self):
179
+ r"""
180
+ Enable sliced VAE decoding.
181
+
182
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
183
+ steps. This is useful to save some memory and allow larger batch sizes.
184
+ """
185
+ self.vae.enable_slicing()
186
+
187
+ def disable_vae_slicing(self):
188
+ r"""
189
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
190
+ computing decoding in one step.
191
+ """
192
+ self.vae.disable_slicing()
193
+
194
+ def enable_sequential_cpu_offload(self, gpu_id=0):
195
+ r"""
196
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
197
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
198
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
199
+ """
200
+ if is_accelerate_available():
201
+ from accelerate import cpu_offload
202
+ else:
203
+ raise ImportError("Please install accelerate via `pip install accelerate`")
204
+
205
+ device = torch.device(f"cuda:{gpu_id}")
206
+
207
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
208
+ if cpu_offloaded_model is not None:
209
+ cpu_offload(cpu_offloaded_model, device)
210
+
211
+ if self.safety_checker is not None:
212
+ # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
213
+ # fix by only offloading self.safety_checker for now
214
+ cpu_offload(self.safety_checker.vision_model, device)
215
+
216
+ @property
217
+ def _execution_device(self):
218
+ r"""
219
+ Returns the device on which the pipeline's models will be executed. After calling
220
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
221
+ hooks.
222
+ """
223
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
224
+ return self.device
225
+ for module in self.unet.modules():
226
+ if (
227
+ hasattr(module, "_hf_hook")
228
+ and hasattr(module._hf_hook, "execution_device")
229
+ and module._hf_hook.execution_device is not None
230
+ ):
231
+ return torch.device(module._hf_hook.execution_device)
232
+ return self.device
233
+
234
+ def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt, uncond_embeddings=None):
235
+ r"""
236
+ Encodes the prompt into text encoder hidden states.
237
+
238
+ Args:
239
+ prompt (`str` or `List[str]`):
240
+ prompt to be encoded
241
+ device: (`torch.device`):
242
+ torch device
243
+ num_videos_per_prompt (`int`):
244
+ number of videos that should be generated per prompt
245
+ do_classifier_free_guidance (`bool`):
246
+ whether to use classifier free guidance or not
247
+ negative_prompt (`str` or `List[str]`):
248
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
249
+ if `guidance_scale` is less than `1`).
250
+ """
251
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
252
+
253
+ text_inputs = self.tokenizer(
254
+ prompt,
255
+ padding="max_length",
256
+ max_length=self.tokenizer.model_max_length,
257
+ truncation=True,
258
+ return_tensors="pt",
259
+ )
260
+ text_input_ids = text_inputs.input_ids
261
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
262
+
263
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
264
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
265
+ logger.warning(
266
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
267
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
268
+ )
269
+
270
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
271
+ attention_mask = text_inputs.attention_mask.to(device)
272
+ else:
273
+ attention_mask = None
274
+
275
+ text_embeddings = self.text_encoder(
276
+ text_input_ids.to(device),
277
+ attention_mask=attention_mask,
278
+ )
279
+ text_embeddings = text_embeddings[0]
280
+
281
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
282
+ # num_videos_per_prompt = 1, thus nothing happens here
283
+ bs_embed, seq_len, _ = text_embeddings.shape
284
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
285
+ text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
286
+
287
+ # get unconditional embeddings for classifier free guidance
288
+ if do_classifier_free_guidance:
289
+ uncond_tokens: List[str]
290
+ if negative_prompt is None:
291
+ uncond_tokens = [""] * batch_size
292
+ elif type(prompt) is not type(negative_prompt):
293
+ raise TypeError(
294
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
295
+ f" {type(prompt)}."
296
+ )
297
+ elif isinstance(negative_prompt, str):
298
+ uncond_tokens = [negative_prompt]
299
+ elif batch_size != len(negative_prompt):
300
+ raise ValueError(
301
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
302
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
303
+ " the batch size of `prompt`."
304
+ )
305
+ else:
306
+ uncond_tokens = negative_prompt
307
+
308
+ max_length = text_input_ids.shape[-1]
309
+ uncond_input = self.tokenizer(
310
+ uncond_tokens,
311
+ padding="max_length",
312
+ max_length=max_length,
313
+ truncation=True,
314
+ return_tensors="pt",
315
+ )
316
+
317
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
318
+ attention_mask = uncond_input.attention_mask.to(device)
319
+ else:
320
+ attention_mask = None
321
+
322
+ uncond_embeddings = self.text_encoder(
323
+ uncond_input.input_ids.to(device),
324
+ attention_mask=attention_mask,
325
+ )
326
+ uncond_embeddings = uncond_embeddings[0]
327
+
328
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
329
+ seq_len = uncond_embeddings.shape[1]
330
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
331
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
332
+
333
+ # For classifier free guidance, we need to do two forward passes.
334
+ # Here we concatenate the unconditional and text embeddings into a single batch
335
+ # to avoid doing two forward passes
336
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
337
+
338
+ return text_embeddings
339
+
340
+ def run_safety_checker(self, image, device, dtype):
341
+ if self.safety_checker is not None:
342
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
343
+ image, has_nsfw_concept = self.safety_checker(
344
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
345
+ )
346
+ else:
347
+ has_nsfw_concept = None
348
+ return image, has_nsfw_concept
349
+
350
+ def decode_latents(self, latents):
351
+ video_length = latents.shape[2]
352
+ latents = 1 / 0.18215 * latents
353
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
354
+ video = self.vae.decode(latents).sample
355
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
356
+ video = (video / 2 + 0.5).clamp(0, 1)
357
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
358
+ video = video.cpu().float().numpy()
359
+ return video
360
+
361
+ def prepare_extra_step_kwargs(self, generator, eta):
362
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
363
+ # eta (Ξ·) is only used with the DDIMScheduler, it will be ignored for other schedulers.
364
+ # eta corresponds to Ξ· in DDIM paper: https://arxiv.org/abs/2010.02502
365
+ # and should be between [0, 1]
366
+
367
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
368
+ extra_step_kwargs = {}
369
+ if accepts_eta:
370
+ extra_step_kwargs["eta"] = eta
371
+
372
+ # check if the scheduler accepts generator
373
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
374
+ if accepts_generator:
375
+ extra_step_kwargs["generator"] = generator
376
+ return extra_step_kwargs
377
+
378
+ def check_inputs(self, prompt, height, width, callback_steps):
379
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
380
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
381
+
382
+ if height % 8 != 0 or width % 8 != 0:
383
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
384
+
385
+ if (callback_steps is None) or (
386
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
387
+ ):
388
+ raise ValueError(
389
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
390
+ f" {type(callback_steps)}."
391
+ )
392
+
393
+ def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
394
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
395
+ if isinstance(generator, list) and len(generator) != batch_size:
396
+ raise ValueError(
397
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
398
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
399
+ )
400
+
401
+ if latents is None:
402
+ rand_device = "cpu" if device.type == "mps" else device
403
+
404
+ if isinstance(generator, list):
405
+ shape = (1,) + shape[1:]
406
+ latents = [
407
+ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
408
+ for i in range(batch_size)
409
+ ]
410
+ latents = torch.cat(latents, dim=0).to(device)
411
+ else:
412
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
413
+ else:
414
+ if latents.shape != shape:
415
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
416
+ latents = latents.to(device)
417
+
418
+ # scale the initial noise by the standard deviation required by the scheduler
419
+ latents = latents * self.scheduler.init_noise_sigma
420
+ return latents
421
+
422
+ @torch.no_grad()
423
+ def __call__(
424
+ self,
425
+ prompt: Union[str, List[str]],
426
+ video_length: Optional[int],
427
+ height: Optional[int] = None,
428
+ width: Optional[int] = None,
429
+ num_inference_steps: int = 50,
430
+ guidance_scale: float = 7.5,
431
+ negative_prompt: Optional[Union[str, List[str]]] = None,
432
+ num_videos_per_prompt: Optional[int] = 1,
433
+ eta: float = 0.0,
434
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
435
+ latents: Optional[torch.FloatTensor] = None,
436
+ output_type: Optional[str] = "tensor",
437
+ return_dict: bool = True,
438
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
439
+ callback_steps: Optional[int] = 1,
440
+ uncond_embeddings: torch.Tensor = None,
441
+ null_uncond_ratio: float = 1.0,
442
+ **kwargs,
443
+ ):
444
+ # Default height and width to unet
445
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
446
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
447
+
448
+ # Check inputs. Raise error if not correct
449
+ self.check_inputs(prompt, height, width, callback_steps)
450
+
451
+ # Define call parameters
452
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
453
+ device = self._execution_device
454
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
455
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
456
+ # corresponds to doing no classifier free guidance.
457
+ do_classifier_free_guidance = guidance_scale > 1.0
458
+
459
+ # Encode input prompt
460
+ with_uncond_embedding = do_classifier_free_guidance if uncond_embeddings is None else False
461
+ text_embeddings = self._encode_prompt(
462
+ prompt, device, num_videos_per_prompt, with_uncond_embedding, negative_prompt,
463
+ )
464
+
465
+ # Prepare timesteps
466
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
467
+ timesteps = self.scheduler.timesteps
468
+
469
+ # Prepare latent variables
470
+ num_channels_latents = self.unet.in_channels
471
+ latents = self.prepare_latents(
472
+ batch_size * num_videos_per_prompt,
473
+ num_channels_latents,
474
+ video_length,
475
+ height,
476
+ width,
477
+ text_embeddings.dtype,
478
+ device,
479
+ generator,
480
+ latents,
481
+ )
482
+ latents_dtype = latents.dtype
483
+
484
+ # Prepare extra step kwargs.
485
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
486
+
487
+ # Denoising loop
488
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
489
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
490
+ if uncond_embeddings is not None:
491
+ start_time = 50
492
+ assert (timesteps[-start_time:] == timesteps).all()
493
+ for i, t in enumerate(timesteps):
494
+ # expand the latents if we are doing classifier free guidance
495
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
496
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
497
+
498
+ if uncond_embeddings is not None:
499
+ use_uncond_this_step = True
500
+ if null_uncond_ratio > 0:  # positive ratio: use the optimized null-text embeddings for the first null_uncond_ratio fraction of steps
501
+ if i > len(timesteps) * null_uncond_ratio:
502
+ use_uncond_this_step = False
503
+ else:  # non-positive ratio: use them only for the last |null_uncond_ratio| fraction of steps
504
+ if i < len(timesteps) * (1 + null_uncond_ratio):
505
+ use_uncond_this_step = False
506
+ if use_uncond_this_step:
507
+ text_embeddings_input = torch.cat([uncond_embeddings[i].expand(*text_embeddings.shape), text_embeddings])
508
+ else:
509
+ uncond_embeddings_ = self._encode_prompt('', device, num_videos_per_prompt, False, negative_prompt)
510
+ text_embeddings_input = torch.cat([uncond_embeddings_.expand(*text_embeddings.shape), text_embeddings])
511
+ else:
512
+ text_embeddings_input = text_embeddings
513
+
514
+ # predict the noise residual
515
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings_input).sample.to(dtype=latents_dtype)
516
+
517
+ # perform guidance
518
+ if do_classifier_free_guidance:
519
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
520
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
521
+
522
+ # compute the previous noisy sample x_t -> x_t-1
523
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
524
+
525
+ # call the callback, if provided
526
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
527
+ progress_bar.update()
528
+ if callback is not None and i % callback_steps == 0:
529
+ callback(i, t, latents)
530
+
531
+ # Post-processing
532
+ images = self.decode_latents(latents)
533
+
534
+ # Convert to tensor
535
+ if output_type == "tensor":
536
+ images = torch.from_numpy(images)
537
+
538
+ if not return_dict:
539
+ return images
540
+
541
+ return Vid2VidZeroPipelineOutput(images=images)
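The pipeline is essentially the Stable Diffusion text-to-image loop adapted to 5D video latents (batch, channels, frames, height, width), with optional per-step null-text embeddings supplied through uncond_embeddings. A minimal inference sketch, not part of this commit, assuming a diffusers version compatible with this repo, the CompVis/stable-diffusion-v1-4 weights, and that the customized UNet2DConditionModel loads with from_pretrained like the stock one:

from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler

from vid2vid_zero.models.unet_2d_condition import UNet2DConditionModel
from vid2vid_zero.pipelines.pipeline_vid2vid_zero import Vid2VidZeroPipeline

model_id = "CompVis/stable-diffusion-v1-4"
pipe = Vid2VidZeroPipeline(
    vae=AutoencoderKL.from_pretrained(model_id, subfolder="vae"),
    text_encoder=CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder"),
    tokenizer=CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer"),
    unet=UNet2DConditionModel.from_pretrained(model_id, subfolder="unet"),
    scheduler=DDIMScheduler.from_pretrained(model_id, subfolder="scheduler"),
    safety_checker=None,       # optional components, see _optional_components
    feature_extractor=None,
).to("cuda")

# Sample an 8-frame clip from noise; out.images has shape (1, 3, 8, 512, 512).
out = pipe("a jeep driving down the road", video_length=8,
           num_inference_steps=50, guidance_scale=7.5)
print(out.images.shape)

In the repo's own scripts and in the Gradio demo, the starting latents typically come from DDIM inversion of the input clip (see vid2vid_zero/util.py below), and uncond_embeddings, when used, from null-text optimization, rather than pure noise as in this sketch.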
vid2vid_zero/util.py ADDED
@@ -0,0 +1,114 @@
1
+ import os
2
+ import imageio
3
+ import tempfile
4
+ import numpy as np
5
+ from PIL import Image
6
+ from typing import Union
7
+
8
+ import torch
9
+ import torchvision
10
+
11
+ from tqdm import tqdm
12
+ from einops import rearrange
13
+
14
+
15
+ def save_videos_as_images(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=1):
16
+ dir_name = os.path.dirname(path)
17
+ videos = rearrange(videos, "b c t h w -> t b h w c")
18
+
19
+ os.makedirs(os.path.join(dir_name, "vis_images"), exist_ok=True)
20
+ for frame_idx, x in enumerate(videos):
21
+ if rescale:
22
+ x = (x + 1.0) / 2.0
23
+ x = (x * 255).numpy().astype(np.uint8)
24
+
25
+ for batch_idx, image in enumerate(x):
26
+ save_dir = os.path.join(dir_name, "vis_images", f"batch_{batch_idx}")
27
+ os.makedirs(save_dir, exist_ok=True)
28
+ save_path = os.path.join(save_dir, f"frame_{frame_idx}.png")
29
+ image = Image.fromarray(image)
30
+ image.save(save_path)
31
+
32
+
33
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=1):
34
+ videos = rearrange(videos, "b c t h w -> t b c h w")
35
+ outputs = []
36
+ for x in videos:
37
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
38
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
39
+ if rescale:
40
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
41
+ x = (x * 255).numpy().astype(np.uint8)
42
+ outputs.append(x)
43
+
44
+ os.makedirs(os.path.dirname(path), exist_ok=True)
45
+ imageio.mimsave(path, outputs, fps=fps)
46
+
47
+ # save for gradio demo
48
+ out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
49
+ out_file.name = path.replace('.gif', '.mp4')  # only the attribute changes; the mp4 is written next to the gif and the temp file itself is left unused
50
+ writer = imageio.get_writer(out_file.name, fps=fps)
51
+ for frame in outputs:
52
+ writer.append_data(frame)
53
+ writer.close()
54
+
55
+
56
+ @torch.no_grad()
57
+ def init_prompt(prompt, pipeline):
58
+ uncond_input = pipeline.tokenizer(
59
+ [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
60
+ return_tensors="pt"
61
+ )
62
+ uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
63
+ text_input = pipeline.tokenizer(
64
+ [prompt],
65
+ padding="max_length",
66
+ max_length=pipeline.tokenizer.model_max_length,
67
+ truncation=True,
68
+ return_tensors="pt",
69
+ )
70
+ text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
71
+ context = torch.cat([uncond_embeddings, text_embeddings])
72
+
73
+ return context
74
+
75
+
76
+ def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
77
+ sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
78
+ timestep, next_timestep = min(
79
+ timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
80
+ alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
81
+ alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
82
+ beta_prod_t = 1 - alpha_prod_t
83
+ next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
84
+ next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
85
+ next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
86
+ return next_sample
87
+
88
+
89
+ def get_noise_pred_single(latents, t, context, unet, normal_infer=False):
90
+ bs = latents.shape[0] # (b*f, c, h, w) or (b, c, f, h, w)
91
+ if bs != context.shape[0]:
92
+ context = context.repeat(bs, 1, 1) # (b*f, len, dim)
93
+ noise_pred = unet(latents, t, encoder_hidden_states=context, normal_infer=normal_infer)["sample"]
94
+ return noise_pred
95
+
96
+
97
+ @torch.no_grad()
98
+ def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt, normal_infer=False):
99
+ context = init_prompt(prompt, pipeline)
100
+ uncond_embeddings, cond_embeddings = context.chunk(2)
101
+ all_latent = [latent]
102
+ latent = latent.clone().detach()
103
+ for i in tqdm(range(num_inv_steps)):
104
+ t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
105
+ noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet, normal_infer=normal_infer)
106
+ latent = next_step(noise_pred, t, latent, ddim_scheduler)
107
+ all_latent.append(latent)
108
+ return all_latent
109
+
110
+
111
+ @torch.no_grad()
112
+ def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt="", normal_infer=False):
113
+ ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt, normal_infer=normal_infer)
114
+ return ddim_latents
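util.py pairs the I/O helpers used by the Gradio demo (save_videos_grid writes both the GIF and an MP4 copy) with the DDIM inversion loop that recovers the latent trajectory of an input clip. A minimal sketch, not part of this commit, of how the inversion pieces might be wired together; it assumes pipe is a Vid2VidZeroPipeline as above and video is a (batch, channels, frames, height, width) tensor scaled to [-1, 1]:

import torch
from einops import rearrange
from diffusers import DDIMScheduler

from vid2vid_zero.util import ddim_inversion

@torch.no_grad()
def invert_clip(pipe, video, prompt, num_inv_steps=50):
    # Separate DDIM scheduler for inversion, sharing the pipeline's config.
    ddim_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    ddim_scheduler.set_timesteps(num_inv_steps)

    # Encode frames into VAE latents (same 0.18215 scaling as decode_latents).
    frames = rearrange(video, "b c f h w -> (b f) c h w").to(pipe.device, dtype=pipe.vae.dtype)
    latents = pipe.vae.encode(frames).latent_dist.sample() * 0.18215
    latents = rearrange(latents, "(b f) c h w -> b c f h w", b=video.shape[0])

    # ddim_inversion returns one latent per step; the last is the fully
    # inverted latent, usable as the pipeline's `latents` argument.
    return ddim_inversion(pipe, ddim_scheduler, latents, num_inv_steps, prompt=prompt)[-1]

Feeding the inverted latent back through the pipeline with the source prompt should approximately reconstruct the input clip, which is the usual sanity check before applying an edited prompt.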