bubbliiiing committed on
Commit e262715
1 Parent(s): 788d423
.gitignore ADDED
@@ -0,0 +1,160 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
app.py CHANGED
@@ -11,9 +11,9 @@ if __name__ == "__main__":
11
  server_port = 7860
12
 
13
  # Params below is used when ui_mode = "modelscope"
14
- edition = "v2"
15
- config_path = "config/easyanimate_video_magvit_motion_module_v2.yaml"
16
- model_name = "models/Diffusion_Transformer/EasyAnimateV2-XL-2-512x512"
17
  savedir_sample = "samples"
18
 
19
  if ui_mode == "modelscope":
 
11
  server_port = 7860
12
 
13
  # Params below is used when ui_mode = "modelscope"
14
+ edition = "v3"
15
+ config_path = "config/easyanimate_video_slicevae_motion_module_v3.yaml"
16
+ model_name = "models/Diffusion_Transformer/EasyAnimateV3-XL-2-InP-512x512"
17
  savedir_sample = "samples"
18
 
19
  if ui_mode == "modelscope":
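
The only change here is that the ModelScope defaults now point at the V3 slice-VAE config and the EasyAnimateV3 inpainting checkpoint instead of the V2 files. A minimal sketch, assuming one wants to keep the per-edition defaults side by side (the dictionary below is illustrative, not part of the commit; the paths are the ones shown in the diff):

EDITION_DEFAULTS = {
    "v2": ("config/easyanimate_video_magvit_motion_module_v2.yaml",
           "models/Diffusion_Transformer/EasyAnimateV2-XL-2-512x512"),
    "v3": ("config/easyanimate_video_slicevae_motion_module_v3.yaml",
           "models/Diffusion_Transformer/EasyAnimateV3-XL-2-InP-512x512"),
}

edition = "v3"
config_path, model_name = EDITION_DEFAULTS[edition]
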
easyanimate/api/api.py CHANGED
@@ -1,10 +1,14 @@
1
  import io
 
2
  import base64
3
  import torch
4
  import gradio as gr
 
 
5
 
6
  from fastapi import FastAPI
7
  from io import BytesIO
 
8
 
9
  # Function to encode a file to Base64
10
  def encode_file_to_base64(file_path):
@@ -59,16 +63,34 @@ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
59
  lora_model_path = datas.get('lora_model_path', 'none')
60
  lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
61
  prompt_textbox = datas.get('prompt_textbox', None)
62
- negative_prompt_textbox = datas.get('negative_prompt_textbox', '')
63
  sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
64
  sample_step_slider = datas.get('sample_step_slider', 30)
 
65
  width_slider = datas.get('width_slider', 672)
66
  height_slider = datas.get('height_slider', 384)
 
67
  is_image = datas.get('is_image', False)
 
68
  length_slider = datas.get('length_slider', 144)
 
 
69
  cfg_scale_slider = datas.get('cfg_scale_slider', 6)
 
 
70
  seed_textbox = datas.get("seed_textbox", 43)
71
 
 
72
  try:
73
  save_sample_path, comment = controller.generate(
74
  "",
@@ -80,17 +102,29 @@ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
80
  negative_prompt_textbox,
81
  sampler_dropdown,
82
  sample_step_slider,
 
83
  width_slider,
84
  height_slider,
85
- is_image,
 
86
  length_slider,
 
 
87
  cfg_scale_slider,
 
 
88
  seed_textbox,
89
  is_api = True,
90
  )
91
  except Exception as e:
 
92
  torch.cuda.empty_cache()
 
93
  save_sample_path = ""
94
  comment = f"Error. error information is {str(e)}"
95
-
96
- return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
 
 
 
 
 
1
  import io
2
+ import gc
3
  import base64
4
  import torch
5
  import gradio as gr
6
+ import tempfile
7
+ import hashlib
8
 
9
  from fastapi import FastAPI
10
  from io import BytesIO
11
+ from PIL import Image
12
 
13
  # Function to encode a file to Base64
14
  def encode_file_to_base64(file_path):
 
63
  lora_model_path = datas.get('lora_model_path', 'none')
64
  lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
65
  prompt_textbox = datas.get('prompt_textbox', None)
66
+ negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion.')
67
  sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
68
  sample_step_slider = datas.get('sample_step_slider', 30)
69
+ resize_method = datas.get('resize_method', "Generate by")
70
  width_slider = datas.get('width_slider', 672)
71
  height_slider = datas.get('height_slider', 384)
72
+ base_resolution = datas.get('base_resolution', 512)
73
  is_image = datas.get('is_image', False)
74
+ generation_method = datas.get('generation_method', False)
75
  length_slider = datas.get('length_slider', 144)
76
+ overlap_video_length = datas.get('overlap_video_length', 4)
77
+ partial_video_length = datas.get('partial_video_length', 72)
78
  cfg_scale_slider = datas.get('cfg_scale_slider', 6)
79
+ start_image = datas.get('start_image', None)
80
+ end_image = datas.get('end_image', None)
81
  seed_textbox = datas.get("seed_textbox", 43)
82
 
83
+ generation_method = "Image Generation" if is_image else generation_method
84
+
85
+ temp_directory = tempfile.gettempdir()
86
+ if start_image is not None:
87
+ start_image = base64.b64decode(start_image)
88
+ start_image = [Image.open(BytesIO(start_image))]
89
+
90
+ if end_image is not None:
91
+ end_image = base64.b64decode(end_image)
92
+ end_image = [Image.open(BytesIO(end_image))]
93
+
94
  try:
95
  save_sample_path, comment = controller.generate(
96
  "",
 
102
  negative_prompt_textbox,
103
  sampler_dropdown,
104
  sample_step_slider,
105
+ resize_method,
106
  width_slider,
107
  height_slider,
108
+ base_resolution,
109
+ generation_method,
110
  length_slider,
111
+ overlap_video_length,
112
+ partial_video_length,
113
  cfg_scale_slider,
114
+ start_image,
115
+ end_image,
116
  seed_textbox,
117
  is_api = True,
118
  )
119
  except Exception as e:
120
+ gc.collect()
121
  torch.cuda.empty_cache()
122
+ torch.cuda.ipc_collect()
123
  save_sample_path = ""
124
  comment = f"Error. error information is {str(e)}"
125
+ return {"message": comment}
126
+
127
+ if save_sample_path != "":
128
+ return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
129
+ else:
130
+ return {"message": comment, "save_sample_path": save_sample_path}
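
For reference, a client-side sketch of the expanded request body. The field names and defaults come from the handler above, and start_image/end_image are plain base64 strings that the handler decodes with PIL; the endpoint path and port are assumptions for illustration (the route itself is registered by infer_forward_api elsewhere):

import base64
import json

import requests  # assumed to be available in the client environment


def encode_file_to_base64(path):
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")


payload = {
    "prompt_textbox": "A cat walking on the beach",
    "generation_method": "Video Generation",   # or "Image Generation"
    "resize_method": "Generate by",
    "width_slider": 672,
    "height_slider": 384,
    "base_resolution": 512,
    "length_slider": 144,
    "overlap_video_length": 4,
    "partial_video_length": 72,
    "cfg_scale_slider": 6,
    "seed_textbox": 43,
    "start_image": encode_file_to_base64("start.png"),  # optional image-to-video conditioning
}

# Hypothetical URL; use whatever host/port serves the FastAPI app.
response = requests.post("http://127.0.0.1:7860/easyanimate/infer_forward",
                         data=json.dumps(payload))
result = response.json()
print(result["message"], result.get("save_sample_path", ""))
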
easyanimate/api/post_infer.py CHANGED
@@ -26,7 +26,7 @@ def post_update_edition(edition, url='http://0.0.0.0:7860'):
26
  data = r.content.decode('utf-8')
27
  return data
28
 
29
- def post_infer(is_image, length_slider, url='http://127.0.0.1:7860'):
30
  datas = json.dumps({
31
  "base_model_path": "none",
32
  "motion_module_path": "none",
@@ -38,7 +38,7 @@ def post_infer(is_image, length_slider, url='http://127.0.0.1:7860'):
38
  "sample_step_slider": 30,
39
  "width_slider": 672,
40
  "height_slider": 384,
41
- "is_image": is_image,
42
  "length_slider": length_slider,
43
  "cfg_scale_slider": 6,
44
  "seed_textbox": 43,
@@ -55,29 +55,31 @@ if __name__ == '__main__':
55
  # -------------------------- #
56
  # Step 1: update edition
57
  # -------------------------- #
58
- edition = "v2"
59
  outputs = post_update_edition(edition)
60
  print('Output update edition: ', outputs)
61
 
62
  # -------------------------- #
63
  # Step 2: update edition
64
  # -------------------------- #
65
- diffusion_transformer_path = "/your-path/EasyAnimate/models/Diffusion_Transformer/EasyAnimateV2-XL-2-512x512"
66
  outputs = post_diffusion_transformer(diffusion_transformer_path)
67
  print('Output update edition: ', outputs)
68
 
69
  # -------------------------- #
70
  # Step 3: infer
71
  # -------------------------- #
72
- is_image = False
73
- length_slider = 27
74
- outputs = post_infer(is_image, length_slider)
 
75
 
76
  # Get decoded data
77
  outputs = json.loads(outputs)
78
  base64_encoding = outputs["base64_encoding"]
79
  decoded_data = base64.b64decode(base64_encoding)
80
 
 
81
  if is_image or length_slider == 1:
82
  file_path = "1.png"
83
  else:
 
26
  data = r.content.decode('utf-8')
27
  return data
28
 
29
+ def post_infer(generation_method, length_slider, url='http://127.0.0.1:7860'):
30
  datas = json.dumps({
31
  "base_model_path": "none",
32
  "motion_module_path": "none",
 
38
  "sample_step_slider": 30,
39
  "width_slider": 672,
40
  "height_slider": 384,
41
+ "generation_method": "Video Generation",
42
  "length_slider": length_slider,
43
  "cfg_scale_slider": 6,
44
  "seed_textbox": 43,
 
55
  # -------------------------- #
56
  # Step 1: update edition
57
  # -------------------------- #
58
+ edition = "v3"
59
  outputs = post_update_edition(edition)
60
  print('Output update edition: ', outputs)
61
 
62
  # -------------------------- #
63
  # Step 2: update edition
64
  # -------------------------- #
65
+ diffusion_transformer_path = "models/Diffusion_Transformer/EasyAnimateV3-XL-2-512x512"
66
  outputs = post_diffusion_transformer(diffusion_transformer_path)
67
  print('Output update edition: ', outputs)
68
 
69
  # -------------------------- #
70
  # Step 3: infer
71
  # -------------------------- #
72
+ # "Video Generation" and "Image Generation"
73
+ generation_method = "Video Generation"
74
+ length_slider = 72
75
+ outputs = post_infer(generation_method, length_slider)
76
 
77
  # Get decoded data
78
  outputs = json.loads(outputs)
79
  base64_encoding = outputs["base64_encoding"]
80
  decoded_data = base64.b64decode(base64_encoding)
81
 
82
+ is_image = True if generation_method == "Image Generation" else False
83
  if is_image or length_slider == 1:
84
  file_path = "1.png"
85
  else:
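
The tail of this hunk is cut off in the view above. An illustrative way to finish the client script (not the committed code, which is truncated here) is to pick the file extension from the generation mode and write the decoded bytes out:

# Illustrative completion: save the decoded result as an image or a video file.
file_path = "1.png" if is_image or length_slider == 1 else "1.mp4"
with open(file_path, "wb") as f:
    f.write(decoded_data)
print(f"saved result to {file_path}")
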
easyanimate/data/dataset_image_video.py CHANGED
@@ -12,6 +12,7 @@ import gc
12
  import numpy as np
13
  import torch
14
  import torchvision.transforms as transforms
 
15
  from func_timeout import func_timeout, FunctionTimedOut
16
  from decord import VideoReader
17
  from PIL import Image
@@ -21,6 +22,52 @@ from contextlib import contextmanager
21
 
22
  VIDEO_READER_TIMEOUT = 20
23
 
 
24
  class ImageVideoSampler(BatchSampler):
25
  """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
26
 
@@ -88,10 +135,11 @@ class ImageVideoDataset(Dataset):
88
  video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
89
  image_sample_size=512,
90
  video_repeat=0,
91
- text_drop_ratio=0.001,
92
  enable_bucket=False,
93
  video_length_drop_start=0.1,
94
  video_length_drop_end=0.9,
 
95
  ):
96
  # Loading annotations from files
97
  print(f"loading annotations from {ann_path} ...")
@@ -120,6 +168,8 @@ class ImageVideoDataset(Dataset):
120
  # TODO: enable bucket training
121
  self.enable_bucket = enable_bucket
122
  self.text_drop_ratio = text_drop_ratio
 
 
123
  self.video_length_drop_start = video_length_drop_start
124
  self.video_length_drop_end = video_length_drop_end
125
 
@@ -165,7 +215,7 @@ class ImageVideoDataset(Dataset):
165
 
166
  video_length = int(self.video_length_drop_end * len(video_reader))
167
  clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
168
- start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length)
169
  batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
170
 
171
  try:
@@ -230,6 +280,17 @@ class ImageVideoDataset(Dataset):
230
  except Exception as e:
231
  print(e, self.dataset[idx % len(self.dataset)])
232
  idx = random.randint(0, self.length-1)
 
233
  return sample
234
 
235
  if __name__ == "__main__":
@@ -238,4 +299,4 @@ if __name__ == "__main__":
238
  )
239
  dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=16)
240
  for idx, batch in enumerate(dataloader):
241
- print(batch["pixel_values"].shape, len(batch["text"]))
 
12
  import numpy as np
13
  import torch
14
  import torchvision.transforms as transforms
15
+
16
  from func_timeout import func_timeout, FunctionTimedOut
17
  from decord import VideoReader
18
  from PIL import Image
 
22
 
23
  VIDEO_READER_TIMEOUT = 20
24
 
25
+ def get_random_mask(shape):
26
+ f, c, h, w = shape
27
+
28
+ if f != 1:
29
+ mask_index = np.random.randint(1, 4)
30
+ else:
31
+ mask_index = np.random.randint(1, 2)
32
+ mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
33
+
34
+ if mask_index == 0:
35
+ center_x = torch.randint(0, w, (1,)).item()
36
+ center_y = torch.randint(0, h, (1,)).item()
37
+ block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # block width range
38
+ block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # block height range
39
+
40
+ start_x = max(center_x - block_size_x // 2, 0)
41
+ end_x = min(center_x + block_size_x // 2, w)
42
+ start_y = max(center_y - block_size_y // 2, 0)
43
+ end_y = min(center_y + block_size_y // 2, h)
44
+ mask[:, :, start_y:end_y, start_x:end_x] = 1
45
+ elif mask_index == 1:
46
+ mask[:, :, :, :] = 1
47
+ elif mask_index == 2:
48
+ mask_frame_index = np.random.randint(1, 5)
49
+ mask[mask_frame_index:, :, :, :] = 1
50
+ elif mask_index == 3:
51
+ mask_frame_index = np.random.randint(1, 5)
52
+ mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
53
+ elif mask_index == 4:
54
+ center_x = torch.randint(0, w, (1,)).item()
55
+ center_y = torch.randint(0, h, (1,)).item()
56
+ block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # block width range
57
+ block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # block height range
58
+
59
+ start_x = max(center_x - block_size_x // 2, 0)
60
+ end_x = min(center_x + block_size_x // 2, w)
61
+ start_y = max(center_y - block_size_y // 2, 0)
62
+ end_y = min(center_y + block_size_y // 2, h)
63
+
64
+ mask_frame_before = np.random.randint(0, f // 2)
65
+ mask_frame_after = np.random.randint(f // 2, f)
66
+ mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
67
+ else:
68
+ raise ValueError(f"The mask_index {mask_index} is not define")
69
+ return mask
70
+
71
  class ImageVideoSampler(BatchSampler):
72
  """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
73
 
 
135
  video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
136
  image_sample_size=512,
137
  video_repeat=0,
138
+ text_drop_ratio=-1,
139
  enable_bucket=False,
140
  video_length_drop_start=0.1,
141
  video_length_drop_end=0.9,
142
+ enable_inpaint=False,
143
  ):
144
  # Loading annotations from files
145
  print(f"loading annotations from {ann_path} ...")
 
168
  # TODO: enable bucket training
169
  self.enable_bucket = enable_bucket
170
  self.text_drop_ratio = text_drop_ratio
171
+ self.enable_inpaint = enable_inpaint
172
+
173
  self.video_length_drop_start = video_length_drop_start
174
  self.video_length_drop_end = video_length_drop_end
175
 
 
215
 
216
  video_length = int(self.video_length_drop_end * len(video_reader))
217
  clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
218
+ start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
219
  batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
220
 
221
  try:
 
280
  except Exception as e:
281
  print(e, self.dataset[idx % len(self.dataset)])
282
  idx = random.randint(0, self.length-1)
283
+
284
+ if self.enable_inpaint and not self.enable_bucket:
285
+ mask = get_random_mask(pixel_values.size())
286
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
287
+ sample["mask_pixel_values"] = mask_pixel_values
288
+ sample["mask"] = mask
289
+
290
+ clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
291
+ clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
292
+ sample["clip_pixel_values"] = clip_pixel_values
293
+
294
  return sample
295
 
296
  if __name__ == "__main__":
 
299
  )
300
  dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=16)
301
  for idx, batch in enumerate(dataloader):
302
+ print(batch["pixel_values"].shape, len(batch["text"]))
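
A small standalone check of the new get_random_mask helper and of the inpainting inputs built in the enable_inpaint branch above: masked positions of pixel_values are overwritten with -1. The shapes are illustrative; the helper and the masking formula are taken verbatim from this file.

import torch

from easyanimate.data.dataset_image_video import get_random_mask

# (frames, channels, height, width), values already normalized to [-1, 1]
pixel_values = torch.rand(16, 3, 256, 256) * 2 - 1

mask = get_random_mask(pixel_values.size())   # (16, 1, 256, 256), uint8 zeros/ones
mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask

assert mask.shape == (16, 1, 256, 256)
assert torch.all(mask_pixel_values[mask.expand_as(pixel_values) == 1] == -1)
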
easyanimate/models/attention.py CHANGED
@@ -11,17 +11,25 @@
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
- import math
15
  from typing import Any, Dict, Optional
16
 
 
 
17
  import torch
18
  import torch.nn.functional as F
19
  import torch.nn.init as init
20
- from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
 
 
 
 
 
 
 
 
 
21
  from diffusers.models.attention import AdaLayerNorm, FeedForward
22
- from diffusers.models.attention_processor import Attention
23
  from diffusers.models.embeddings import SinusoidalPositionalEmbedding
24
- from diffusers.models.lora import LoRACompatibleLinear
25
  from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero
26
  from diffusers.utils import USE_PEFT_BACKEND
27
  from diffusers.utils.import_utils import is_xformers_available
@@ -29,7 +37,8 @@ from diffusers.utils.torch_utils import maybe_allow_in_graph
29
  from einops import rearrange, repeat
30
  from torch import nn
31
 
32
- from .motion_module import get_motion_module
 
33
 
34
  if is_xformers_available():
35
  import xformers
@@ -38,6 +47,13 @@ else:
38
  xformers = None
39
 
40
 
 
 
 
 
 
 
 
41
  @maybe_allow_in_graph
42
  class GatedSelfAttentionDense(nn.Module):
43
  r"""
@@ -59,8 +75,8 @@ class GatedSelfAttentionDense(nn.Module):
59
  self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
60
  self.ff = FeedForward(query_dim, activation_fn="geglu")
61
 
62
- self.norm1 = nn.LayerNorm(query_dim)
63
- self.norm2 = nn.LayerNorm(query_dim)
64
 
65
  self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
66
  self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
@@ -80,14 +96,6 @@ class GatedSelfAttentionDense(nn.Module):
80
  return x
81
 
82
 
83
- def zero_module(module):
84
- # Zero out the parameters of a module and return it.
85
- for p in module.parameters():
86
- p.detach().zero_()
87
- return module
88
-
89
-
90
-
91
  class KVCompressionCrossAttention(nn.Module):
92
  r"""
93
  A cross attention layer.
@@ -154,7 +162,7 @@ class KVCompressionCrossAttention(nn.Module):
154
  stride=2,
155
  bias=True
156
  )
157
- self.kv_compression_norm = nn.LayerNorm(query_dim)
158
  init.constant_(self.kv_compression.weight, 1 / 4)
159
  if self.kv_compression.bias is not None:
160
  init.constant_(self.kv_compression.bias, 0)
@@ -410,6 +418,8 @@ class TemporalTransformerBlock(nn.Module):
410
  # motion module kwargs
411
  motion_module_type = "VanillaGrid",
412
  motion_module_kwargs = None,
 
 
413
  ):
414
  super().__init__()
415
  self.only_cross_attention = only_cross_attention
@@ -442,7 +452,7 @@ class TemporalTransformerBlock(nn.Module):
442
  elif self.use_ada_layer_norm_zero:
443
  self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
444
  else:
445
- self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
446
 
447
  self.kvcompression = kvcompression
448
  if kvcompression:
@@ -456,16 +466,28 @@ class TemporalTransformerBlock(nn.Module):
456
  upcast_attention=upcast_attention,
457
  )
458
  else:
459
- self.attn1 = Attention(
460
- query_dim=dim,
461
- heads=num_attention_heads,
462
- dim_head=attention_head_dim,
463
- dropout=dropout,
464
- bias=attention_bias,
465
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
466
- upcast_attention=upcast_attention,
467
- )
468
- print(self.attn1)
 
 
 
 
 
 
 
 
 
 
 
 
469
 
470
  self.attn_temporal = get_motion_module(
471
  in_channels = dim,
@@ -481,27 +503,45 @@ class TemporalTransformerBlock(nn.Module):
481
  self.norm2 = (
482
  AdaLayerNorm(dim, num_embeds_ada_norm)
483
  if self.use_ada_layer_norm
484
- else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
485
  )
486
- self.attn2 = Attention(
487
- query_dim=dim,
488
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
489
- heads=num_attention_heads,
490
- dim_head=attention_head_dim,
491
- dropout=dropout,
492
- bias=attention_bias,
493
- upcast_attention=upcast_attention,
494
- ) # is self-attn if encoder_hidden_states is none
 
 
 
 
 
 
 
 
 
 
 
 
 
495
  else:
496
  self.norm2 = None
497
  self.attn2 = None
498
 
499
  # 3. Feed-forward
500
  if not self.use_ada_layer_norm_single:
501
- self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
502
 
503
  self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
504
 
 
 
 
 
 
505
  # 4. Fuser
506
  if attention_type == "gated" or attention_type == "gated-text-image":
507
  self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
@@ -654,6 +694,9 @@ class TemporalTransformerBlock(nn.Module):
654
  )
655
  else:
656
  ff_output = self.ff(norm_hidden_states, scale=lora_scale)
 
 
 
657
 
658
  if self.use_ada_layer_norm_zero:
659
  ff_output = gate_mlp.unsqueeze(1) * ff_output
@@ -723,6 +766,8 @@ class SelfAttentionTemporalTransformerBlock(nn.Module):
723
  attention_type: str = "default",
724
  positional_embeddings: Optional[str] = None,
725
  num_positional_embeddings: Optional[int] = None,
 
 
726
  ):
727
  super().__init__()
728
  self.only_cross_attention = only_cross_attention
@@ -755,17 +800,30 @@ class SelfAttentionTemporalTransformerBlock(nn.Module):
755
  elif self.use_ada_layer_norm_zero:
756
  self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
757
  else:
758
- self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
759
 
760
- self.attn1 = Attention(
761
- query_dim=dim,
762
- heads=num_attention_heads,
763
- dim_head=attention_head_dim,
764
- dropout=dropout,
765
- bias=attention_bias,
766
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
767
- upcast_attention=upcast_attention,
768
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
769
 
770
  # 2. Cross-Attn
771
  if cross_attention_dim is not None or double_self_attention:
@@ -775,27 +833,45 @@ class SelfAttentionTemporalTransformerBlock(nn.Module):
775
  self.norm2 = (
776
  AdaLayerNorm(dim, num_embeds_ada_norm)
777
  if self.use_ada_layer_norm
778
- else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
779
  )
780
- self.attn2 = Attention(
781
- query_dim=dim,
782
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
783
- heads=num_attention_heads,
784
- dim_head=attention_head_dim,
785
- dropout=dropout,
786
- bias=attention_bias,
787
- upcast_attention=upcast_attention,
788
- ) # is self-attn if encoder_hidden_states is none
 
 
 
 
 
 
 
 
 
 
 
 
 
789
  else:
790
  self.norm2 = None
791
  self.attn2 = None
792
 
793
  # 3. Feed-forward
794
  if not self.use_ada_layer_norm_single:
795
- self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
796
 
797
  self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
798
 
 
 
 
 
 
799
  # 4. Fuser
800
  if attention_type == "gated" or attention_type == "gated-text-image":
801
  self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
@@ -927,6 +1003,9 @@ class SelfAttentionTemporalTransformerBlock(nn.Module):
927
  )
928
  else:
929
  ff_output = self.ff(norm_hidden_states, scale=lora_scale)
 
 
 
930
 
931
  if self.use_ada_layer_norm_zero:
932
  ff_output = gate_mlp.unsqueeze(1) * ff_output
@@ -997,6 +1076,8 @@ class KVCompressionTransformerBlock(nn.Module):
997
  positional_embeddings: Optional[str] = None,
998
  num_positional_embeddings: Optional[int] = None,
999
  kvcompression: Optional[bool] = False,
 
 
1000
  ):
1001
  super().__init__()
1002
  self.only_cross_attention = only_cross_attention
@@ -1029,7 +1110,7 @@ class KVCompressionTransformerBlock(nn.Module):
1029
  elif self.use_ada_layer_norm_zero:
1030
  self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
1031
  else:
1032
- self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
1033
 
1034
  self.kvcompression = kvcompression
1035
  if kvcompression:
@@ -1043,16 +1124,28 @@ class KVCompressionTransformerBlock(nn.Module):
1043
  upcast_attention=upcast_attention,
1044
  )
1045
  else:
1046
- self.attn1 = Attention(
1047
- query_dim=dim,
1048
- heads=num_attention_heads,
1049
- dim_head=attention_head_dim,
1050
- dropout=dropout,
1051
- bias=attention_bias,
1052
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
1053
- upcast_attention=upcast_attention,
1054
- )
1055
- print(self.attn1)
 
 
 
 
 
 
 
 
 
 
 
 
1056
 
1057
  # 2. Cross-Attn
1058
  if cross_attention_dim is not None or double_self_attention:
@@ -1062,27 +1155,45 @@ class KVCompressionTransformerBlock(nn.Module):
1062
  self.norm2 = (
1063
  AdaLayerNorm(dim, num_embeds_ada_norm)
1064
  if self.use_ada_layer_norm
1065
- else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
1066
  )
1067
- self.attn2 = Attention(
1068
- query_dim=dim,
1069
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
1070
- heads=num_attention_heads,
1071
- dim_head=attention_head_dim,
1072
- dropout=dropout,
1073
- bias=attention_bias,
1074
- upcast_attention=upcast_attention,
1075
- ) # is self-attn if encoder_hidden_states is none
 
 
 
 
 
 
 
 
 
 
 
 
 
1076
  else:
1077
  self.norm2 = None
1078
  self.attn2 = None
1079
 
1080
  # 3. Feed-forward
1081
  if not self.use_ada_layer_norm_single:
1082
- self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
1083
 
1084
  self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
1085
 
 
 
 
 
 
1086
  # 4. Fuser
1087
  if attention_type == "gated" or attention_type == "gated-text-image":
1088
  self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
@@ -1229,6 +1340,9 @@ class KVCompressionTransformerBlock(nn.Module):
1229
  )
1230
  else:
1231
  ff_output = self.ff(norm_hidden_states, scale=lora_scale)
 
 
 
1232
 
1233
  if self.use_ada_layer_norm_zero:
1234
  ff_output = gate_mlp.unsqueeze(1) * ff_output
@@ -1239,61 +1353,4 @@ class KVCompressionTransformerBlock(nn.Module):
1239
  if hidden_states.ndim == 4:
1240
  hidden_states = hidden_states.squeeze(1)
1241
 
1242
- return hidden_states
1243
-
1244
-
1245
- class FeedForward(nn.Module):
1246
- r"""
1247
- A feed-forward layer.
1248
-
1249
- Parameters:
1250
- dim (`int`): The number of channels in the input.
1251
- dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
1252
- mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
1253
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
1254
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
1255
- final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
1256
- """
1257
-
1258
- def __init__(
1259
- self,
1260
- dim: int,
1261
- dim_out: Optional[int] = None,
1262
- mult: int = 4,
1263
- dropout: float = 0.0,
1264
- activation_fn: str = "geglu",
1265
- final_dropout: bool = False,
1266
- ):
1267
- super().__init__()
1268
- inner_dim = int(dim * mult)
1269
- dim_out = dim_out if dim_out is not None else dim
1270
- linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
1271
-
1272
- if activation_fn == "gelu":
1273
- act_fn = GELU(dim, inner_dim)
1274
- if activation_fn == "gelu-approximate":
1275
- act_fn = GELU(dim, inner_dim, approximate="tanh")
1276
- elif activation_fn == "geglu":
1277
- act_fn = GEGLU(dim, inner_dim)
1278
- elif activation_fn == "geglu-approximate":
1279
- act_fn = ApproximateGELU(dim, inner_dim)
1280
-
1281
- self.net = nn.ModuleList([])
1282
- # project in
1283
- self.net.append(act_fn)
1284
- # project dropout
1285
- self.net.append(nn.Dropout(dropout))
1286
- # project out
1287
- self.net.append(linear_cls(inner_dim, dim_out))
1288
- # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
1289
- if final_dropout:
1290
- self.net.append(nn.Dropout(dropout))
1291
-
1292
- def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
1293
- compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
1294
- for module in self.net:
1295
- if isinstance(module, compatible_cls):
1296
- hidden_states = module(hidden_states, scale)
1297
- else:
1298
- hidden_states = module(hidden_states)
1299
- return hidden_states
 
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
 
14
  from typing import Any, Dict, Optional
15
 
16
+ import diffusers
17
+ import pkg_resources
18
  import torch
19
  import torch.nn.functional as F
20
  import torch.nn.init as init
21
+
22
+ installed_version = diffusers.__version__
23
+
24
+ if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
25
+ from diffusers.models.attention_processor import (Attention,
26
+ AttnProcessor2_0,
27
+ HunyuanAttnProcessor2_0)
28
+ else:
29
+ from diffusers.models.attention_processor import Attention, AttnProcessor2_0
30
+
31
  from diffusers.models.attention import AdaLayerNorm, FeedForward
 
32
  from diffusers.models.embeddings import SinusoidalPositionalEmbedding
 
33
  from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero
34
  from diffusers.utils import USE_PEFT_BACKEND
35
  from diffusers.utils.import_utils import is_xformers_available
 
37
  from einops import rearrange, repeat
38
  from torch import nn
39
 
40
+ from .motion_module import PositionalEncoding, get_motion_module
41
+ from .norm import FP32LayerNorm
42
 
43
  if is_xformers_available():
44
  import xformers
 
47
  xformers = None
48
 
49
 
50
+ def zero_module(module):
51
+ # Zero out the parameters of a module and return it.
52
+ for p in module.parameters():
53
+ p.detach().zero_()
54
+ return module
55
+
56
+
57
  @maybe_allow_in_graph
58
  class GatedSelfAttentionDense(nn.Module):
59
  r"""
 
75
  self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
76
  self.ff = FeedForward(query_dim, activation_fn="geglu")
77
 
78
+ self.norm1 = FP32LayerNorm(query_dim)
79
+ self.norm2 = FP32LayerNorm(query_dim)
80
 
81
  self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
82
  self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
 
96
  return x
97
 
98
 
 
 
 
 
 
 
 
 
99
  class KVCompressionCrossAttention(nn.Module):
100
  r"""
101
  A cross attention layer.
 
162
  stride=2,
163
  bias=True
164
  )
165
+ self.kv_compression_norm = FP32LayerNorm(query_dim)
166
  init.constant_(self.kv_compression.weight, 1 / 4)
167
  if self.kv_compression.bias is not None:
168
  init.constant_(self.kv_compression.bias, 0)
 
418
  # motion module kwargs
419
  motion_module_type = "VanillaGrid",
420
  motion_module_kwargs = None,
421
+ qk_norm = False,
422
+ after_norm = False,
423
  ):
424
  super().__init__()
425
  self.only_cross_attention = only_cross_attention
 
452
  elif self.use_ada_layer_norm_zero:
453
  self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
454
  else:
455
+ self.norm1 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
456
 
457
  self.kvcompression = kvcompression
458
  if kvcompression:
 
466
  upcast_attention=upcast_attention,
467
  )
468
  else:
469
+ if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
470
+ self.attn1 = Attention(
471
+ query_dim=dim,
472
+ heads=num_attention_heads,
473
+ dim_head=attention_head_dim,
474
+ dropout=dropout,
475
+ bias=attention_bias,
476
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
477
+ upcast_attention=upcast_attention,
478
+ qk_norm="layer_norm" if qk_norm else None,
479
+ processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
480
+ )
481
+ else:
482
+ self.attn1 = Attention(
483
+ query_dim=dim,
484
+ heads=num_attention_heads,
485
+ dim_head=attention_head_dim,
486
+ dropout=dropout,
487
+ bias=attention_bias,
488
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
489
+ upcast_attention=upcast_attention,
490
+ )
491
 
492
  self.attn_temporal = get_motion_module(
493
  in_channels = dim,
 
503
  self.norm2 = (
504
  AdaLayerNorm(dim, num_embeds_ada_norm)
505
  if self.use_ada_layer_norm
506
+ else FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
507
  )
508
+ if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
509
+ self.attn2 = Attention(
510
+ query_dim=dim,
511
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
512
+ heads=num_attention_heads,
513
+ dim_head=attention_head_dim,
514
+ dropout=dropout,
515
+ bias=attention_bias,
516
+ upcast_attention=upcast_attention,
517
+ qk_norm="layer_norm" if qk_norm else None,
518
+ processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
519
+ ) # is self-attn if encoder_hidden_states is none
520
+ else:
521
+ self.attn2 = Attention(
522
+ query_dim=dim,
523
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
524
+ heads=num_attention_heads,
525
+ dim_head=attention_head_dim,
526
+ dropout=dropout,
527
+ bias=attention_bias,
528
+ upcast_attention=upcast_attention,
529
+ ) # is self-attn if encoder_hidden_states is none
530
  else:
531
  self.norm2 = None
532
  self.attn2 = None
533
 
534
  # 3. Feed-forward
535
  if not self.use_ada_layer_norm_single:
536
+ self.norm3 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
537
 
538
  self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
539
 
540
+ if after_norm:
541
+ self.norm4 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
542
+ else:
543
+ self.norm4 = None
544
+
545
  # 4. Fuser
546
  if attention_type == "gated" or attention_type == "gated-text-image":
547
  self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
 
694
  )
695
  else:
696
  ff_output = self.ff(norm_hidden_states, scale=lora_scale)
697
+
698
+ if self.norm4 is not None:
699
+ ff_output = self.norm4(ff_output)
700
 
701
  if self.use_ada_layer_norm_zero:
702
  ff_output = gate_mlp.unsqueeze(1) * ff_output
 
766
  attention_type: str = "default",
767
  positional_embeddings: Optional[str] = None,
768
  num_positional_embeddings: Optional[int] = None,
769
+ qk_norm = False,
770
+ after_norm = False,
771
  ):
772
  super().__init__()
773
  self.only_cross_attention = only_cross_attention
 
800
  elif self.use_ada_layer_norm_zero:
801
  self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
802
  else:
803
+ self.norm1 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
804
 
805
+ if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
806
+ self.attn1 = Attention(
807
+ query_dim=dim,
808
+ heads=num_attention_heads,
809
+ dim_head=attention_head_dim,
810
+ dropout=dropout,
811
+ bias=attention_bias,
812
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
813
+ upcast_attention=upcast_attention,
814
+ qk_norm="layer_norm" if qk_norm else None,
815
+ processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
816
+ )
817
+ else:
818
+ self.attn1 = Attention(
819
+ query_dim=dim,
820
+ heads=num_attention_heads,
821
+ dim_head=attention_head_dim,
822
+ dropout=dropout,
823
+ bias=attention_bias,
824
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
825
+ upcast_attention=upcast_attention,
826
+ )
827
 
828
  # 2. Cross-Attn
829
  if cross_attention_dim is not None or double_self_attention:
 
833
  self.norm2 = (
834
  AdaLayerNorm(dim, num_embeds_ada_norm)
835
  if self.use_ada_layer_norm
836
+ else FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
837
  )
838
+ if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
839
+ self.attn2 = Attention(
840
+ query_dim=dim,
841
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
842
+ heads=num_attention_heads,
843
+ dim_head=attention_head_dim,
844
+ dropout=dropout,
845
+ bias=attention_bias,
846
+ upcast_attention=upcast_attention,
847
+ qk_norm="layer_norm" if qk_norm else None,
848
+ processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
849
+ ) # is self-attn if encoder_hidden_states is none
850
+ else:
851
+ self.attn2 = Attention(
852
+ query_dim=dim,
853
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
854
+ heads=num_attention_heads,
855
+ dim_head=attention_head_dim,
856
+ dropout=dropout,
857
+ bias=attention_bias,
858
+ upcast_attention=upcast_attention,
859
+ ) # is self-attn if encoder_hidden_states is none
860
  else:
861
  self.norm2 = None
862
  self.attn2 = None
863
 
864
  # 3. Feed-forward
865
  if not self.use_ada_layer_norm_single:
866
+ self.norm3 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
867
 
868
  self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
869
 
870
+ if after_norm:
871
+ self.norm4 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
872
+ else:
873
+ self.norm4 = None
874
+
875
  # 4. Fuser
876
  if attention_type == "gated" or attention_type == "gated-text-image":
877
  self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
 
1003
  )
1004
  else:
1005
  ff_output = self.ff(norm_hidden_states, scale=lora_scale)
1006
+
1007
+ if self.norm4 is not None:
1008
+ ff_output = self.norm4(ff_output)
1009
 
1010
  if self.use_ada_layer_norm_zero:
1011
  ff_output = gate_mlp.unsqueeze(1) * ff_output
 
1076
  positional_embeddings: Optional[str] = None,
1077
  num_positional_embeddings: Optional[int] = None,
1078
  kvcompression: Optional[bool] = False,
1079
+ qk_norm = False,
1080
+ after_norm = False,
1081
  ):
1082
  super().__init__()
1083
  self.only_cross_attention = only_cross_attention
 
1110
  elif self.use_ada_layer_norm_zero:
1111
  self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
1112
  else:
1113
+ self.norm1 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
1114
 
1115
  self.kvcompression = kvcompression
1116
  if kvcompression:
 
1124
  upcast_attention=upcast_attention,
1125
  )
1126
  else:
1127
+ if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
1128
+ self.attn1 = Attention(
1129
+ query_dim=dim,
1130
+ heads=num_attention_heads,
1131
+ dim_head=attention_head_dim,
1132
+ dropout=dropout,
1133
+ bias=attention_bias,
1134
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
1135
+ upcast_attention=upcast_attention,
1136
+ qk_norm="layer_norm" if qk_norm else None,
1137
+ processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
1138
+ )
1139
+ else:
1140
+ self.attn1 = Attention(
1141
+ query_dim=dim,
1142
+ heads=num_attention_heads,
1143
+ dim_head=attention_head_dim,
1144
+ dropout=dropout,
1145
+ bias=attention_bias,
1146
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
1147
+ upcast_attention=upcast_attention,
1148
+ )
1149
 
1150
  # 2. Cross-Attn
1151
  if cross_attention_dim is not None or double_self_attention:
 
1155
  self.norm2 = (
1156
  AdaLayerNorm(dim, num_embeds_ada_norm)
1157
  if self.use_ada_layer_norm
1158
+ else FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
1159
  )
1160
+ if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
1161
+ self.attn2 = Attention(
1162
+ query_dim=dim,
1163
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
1164
+ heads=num_attention_heads,
1165
+ dim_head=attention_head_dim,
1166
+ dropout=dropout,
1167
+ bias=attention_bias,
1168
+ upcast_attention=upcast_attention,
1169
+ qk_norm="layer_norm" if qk_norm else None,
1170
+ processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
1171
+ ) # is self-attn if encoder_hidden_states is none
1172
+ else:
1173
+ self.attn2 = Attention(
1174
+ query_dim=dim,
1175
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
1176
+ heads=num_attention_heads,
1177
+ dim_head=attention_head_dim,
1178
+ dropout=dropout,
1179
+ bias=attention_bias,
1180
+ upcast_attention=upcast_attention,
1181
+ ) # is self-attn if encoder_hidden_states is none
1182
  else:
1183
  self.norm2 = None
1184
  self.attn2 = None
1185
 
1186
  # 3. Feed-forward
1187
  if not self.use_ada_layer_norm_single:
1188
+ self.norm3 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
1189
 
1190
  self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
1191
 
1192
+ if after_norm:
1193
+ self.norm4 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
1194
+ else:
1195
+ self.norm4 = None
1196
+
1197
  # 4. Fuser
1198
  if attention_type == "gated" or attention_type == "gated-text-image":
1199
  self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
 
1340
  )
1341
  else:
1342
  ff_output = self.ff(norm_hidden_states, scale=lora_scale)
1343
+
1344
+ if self.norm4 is not None:
1345
+ ff_output = self.norm4(ff_output)
1346
 
1347
  if self.use_ada_layer_norm_zero:
1348
  ff_output = gate_mlp.unsqueeze(1) * ff_output
 
1353
  if hidden_states.ndim == 4:
1354
  hidden_states = hidden_states.squeeze(1)
1355
 
1356
+ return hidden_states
 
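attention.py now builds its plain normalization layers as FP32LayerNorm from the .norm module and version-gates the Attention construction, so qk_norm="layer_norm" and HunyuanAttnProcessor2_0 are only requested on diffusers >= 0.28.2. The .norm module itself is not part of this diff; the sketch below shows the behaviour such a layer is assumed to have, namely normalizing in float32 and casting back to the input dtype:

import torch
import torch.nn as nn
import torch.nn.functional as F


class FP32LayerNorm(nn.LayerNorm):
    # Assumed implementation: run LayerNorm in float32 so fp16/bf16 activations
    # do not lose precision inside the normalization, then restore the dtype.
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        origin_dtype = inputs.dtype
        out = F.layer_norm(
            inputs.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return out.to(origin_dtype)
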
easyanimate/models/autoencoder_magvit.py CHANGED
@@ -17,7 +17,12 @@ import torch
17
  import torch.nn as nn
18
  import torch.nn.functional as F
19
  from diffusers.configuration_utils import ConfigMixin, register_to_config
20
- from diffusers.loaders import FromOriginalVAEMixin
 
 
 
 
 
21
  from diffusers.models.attention_processor import (
22
  ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, Attention,
23
  AttentionProcessor, AttnAddedKVProcessor, AttnProcessor)
@@ -93,6 +98,7 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
93
  norm_num_groups: int = 32,
94
  scaling_factor: float = 0.1825,
95
  slice_compression_vae=False,
 
96
  mini_batch_encoder=9,
97
  mini_batch_decoder=3,
98
  ):
@@ -145,8 +151,8 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
145
  self.mini_batch_encoder = mini_batch_encoder
146
  self.mini_batch_decoder = mini_batch_decoder
147
  self.use_slicing = False
148
- self.use_tiling = False
149
- self.tile_sample_min_size = 256
150
  self.tile_overlap_factor = 0.25
151
  self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(ch_mult) - 1)))
152
  self.scaling_factor = scaling_factor
 
17
  import torch.nn as nn
18
  import torch.nn.functional as F
19
  from diffusers.configuration_utils import ConfigMixin, register_to_config
20
+
21
+ try:
22
+ from diffusers.loaders import FromOriginalVAEMixin
23
+ except:
24
+ from diffusers.loaders import FromOriginalModelMixin as FromOriginalVAEMixin
25
+
26
  from diffusers.models.attention_processor import (
27
  ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, Attention,
28
  AttentionProcessor, AttnAddedKVProcessor, AttnProcessor)
 
98
  norm_num_groups: int = 32,
99
  scaling_factor: float = 0.1825,
100
  slice_compression_vae=False,
101
+ use_tiling=False,
102
  mini_batch_encoder=9,
103
  mini_batch_decoder=3,
104
  ):
 
151
  self.mini_batch_encoder = mini_batch_encoder
152
  self.mini_batch_decoder = mini_batch_decoder
153
  self.use_slicing = False
154
+ self.use_tiling = use_tiling
155
+ self.tile_sample_min_size = 384
156
  self.tile_overlap_factor = 0.25
157
  self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(ch_mult) - 1)))
158
  self.scaling_factor = scaling_factor
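
The VAE now exposes tiling as a constructor argument (use_tiling) instead of a hard-coded False, and the tile size grows from 256 to 384 pixels with the same 0.25 overlap factor, which bounds peak memory when encoding or decoding large frames. A hedged example of switching it on; the checkpoint path and subfolder layout are assumptions for illustration:

from easyanimate.models.autoencoder_magvit import AutoencoderKLMagvit

# use_tiling is a registered config argument, so it can be overridden at load time.
vae = AutoencoderKLMagvit.from_pretrained(
    "models/Diffusion_Transformer/EasyAnimateV3-XL-2-InP-512x512",  # assumed layout
    subfolder="vae",
    use_tiling=True,
)
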
easyanimate/models/motion_module.py CHANGED
@@ -1,248 +1,33 @@
1
  """Modified from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
2
  """
3
  import math
4
- from typing import Any, Callable, List, Optional, Tuple, Union
5
 
 
 
6
  import torch
7
- import torch.nn.functional as F
 
 
 
 
 
 
 
 
 
8
  from diffusers.models.attention import FeedForward
9
  from diffusers.utils.import_utils import is_xformers_available
10
  from einops import rearrange, repeat
11
  from torch import nn
12
 
 
 
13
  if is_xformers_available():
14
  import xformers
15
  import xformers.ops
16
  else:
17
  xformers = None
18
 
19
- class CrossAttention(nn.Module):
20
- r"""
21
- A cross attention layer.
22
-
23
- Parameters:
24
- query_dim (`int`): The number of channels in the query.
25
- cross_attention_dim (`int`, *optional*):
26
- The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
27
- heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
28
- dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
29
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
30
- bias (`bool`, *optional*, defaults to False):
31
- Set to `True` for the query, key, and value linear layers to contain a bias parameter.
32
- """
33
-
34
- def __init__(
35
- self,
36
- query_dim: int,
37
- cross_attention_dim: Optional[int] = None,
38
- heads: int = 8,
39
- dim_head: int = 64,
40
- dropout: float = 0.0,
41
- bias=False,
42
- upcast_attention: bool = False,
43
- upcast_softmax: bool = False,
44
- added_kv_proj_dim: Optional[int] = None,
45
- norm_num_groups: Optional[int] = None,
46
- ):
47
- super().__init__()
48
- inner_dim = dim_head * heads
49
- cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
50
- self.upcast_attention = upcast_attention
51
- self.upcast_softmax = upcast_softmax
52
-
53
- self.scale = dim_head**-0.5
54
-
55
- self.heads = heads
56
- # for slice_size > 0 the attention score computation
57
- # is split across the batch axis to save memory
58
- # You can set slice_size with `set_attention_slice`
59
- self.sliceable_head_dim = heads
60
- self._slice_size = None
61
- self._use_memory_efficient_attention_xformers = False
62
- self.added_kv_proj_dim = added_kv_proj_dim
63
-
64
- if norm_num_groups is not None:
65
- self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
66
- else:
67
- self.group_norm = None
68
-
69
- self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
70
- self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
71
- self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
72
-
73
- if self.added_kv_proj_dim is not None:
74
- self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
75
- self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
76
-
77
- self.to_out = nn.ModuleList([])
78
- self.to_out.append(nn.Linear(inner_dim, query_dim))
79
- self.to_out.append(nn.Dropout(dropout))
80
-
81
- def set_use_memory_efficient_attention_xformers(
82
- self, valid: bool, attention_op: Optional[Callable] = None
83
- ) -> None:
84
- self._use_memory_efficient_attention_xformers = valid
85
-
86
- def reshape_heads_to_batch_dim(self, tensor):
87
- batch_size, seq_len, dim = tensor.shape
88
- head_size = self.heads
89
- tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
90
- tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
91
- return tensor
92
-
93
- def reshape_batch_dim_to_heads(self, tensor):
94
- batch_size, seq_len, dim = tensor.shape
95
- head_size = self.heads
96
- tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
97
- tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
98
- return tensor
99
-
100
- def set_attention_slice(self, slice_size):
101
- if slice_size is not None and slice_size > self.sliceable_head_dim:
102
- raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
103
-
104
- self._slice_size = slice_size
105
-
106
- def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
107
- batch_size, sequence_length, _ = hidden_states.shape
108
-
109
- encoder_hidden_states = encoder_hidden_states
110
-
111
- if self.group_norm is not None:
112
- hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
113
-
114
- query = self.to_q(hidden_states)
115
- dim = query.shape[-1]
116
- query = self.reshape_heads_to_batch_dim(query)
117
-
118
- if self.added_kv_proj_dim is not None:
119
- key = self.to_k(hidden_states)
120
- value = self.to_v(hidden_states)
121
- encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
122
- encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
123
-
124
- key = self.reshape_heads_to_batch_dim(key)
125
- value = self.reshape_heads_to_batch_dim(value)
126
- encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
127
- encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
128
-
129
- key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
130
- value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
131
- else:
132
- encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
133
- key = self.to_k(encoder_hidden_states)
134
- value = self.to_v(encoder_hidden_states)
135
-
136
- key = self.reshape_heads_to_batch_dim(key)
137
- value = self.reshape_heads_to_batch_dim(value)
138
-
139
- if attention_mask is not None:
140
- if attention_mask.shape[-1] != query.shape[1]:
141
- target_length = query.shape[1]
142
- attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
143
- attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
144
-
145
- # attention, what we cannot get enough of
146
- if self._use_memory_efficient_attention_xformers:
147
- hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
148
- # Some versions of xformers return output in fp32, cast it back to the dtype of the input
149
- hidden_states = hidden_states.to(query.dtype)
150
- else:
151
- if self._slice_size is None or query.shape[0] // self._slice_size == 1:
152
- hidden_states = self._attention(query, key, value, attention_mask)
153
- else:
154
- hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
155
-
156
- # linear proj
157
- hidden_states = self.to_out[0](hidden_states)
158
-
159
- # dropout
160
- hidden_states = self.to_out[1](hidden_states)
161
- return hidden_states
162
-
163
- def _attention(self, query, key, value, attention_mask=None):
164
- if self.upcast_attention:
165
- query = query.float()
166
- key = key.float()
167
-
168
- attention_scores = torch.baddbmm(
169
- torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
170
- query,
171
- key.transpose(-1, -2),
172
- beta=0,
173
- alpha=self.scale,
174
- )
175
-
176
- if attention_mask is not None:
177
- attention_scores = attention_scores + attention_mask
178
-
179
- if self.upcast_softmax:
180
- attention_scores = attention_scores.float()
181
-
182
- attention_probs = attention_scores.softmax(dim=-1)
183
-
184
- # cast back to the original dtype
185
- attention_probs = attention_probs.to(value.dtype)
186
-
187
- # compute attention output
188
- hidden_states = torch.bmm(attention_probs, value)
189
-
190
- # reshape hidden_states
191
- hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
192
- return hidden_states
193
-
194
- def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
195
- batch_size_attention = query.shape[0]
196
- hidden_states = torch.zeros(
197
- (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
198
- )
199
- slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
200
- for i in range(hidden_states.shape[0] // slice_size):
201
- start_idx = i * slice_size
202
- end_idx = (i + 1) * slice_size
203
-
204
- query_slice = query[start_idx:end_idx]
205
- key_slice = key[start_idx:end_idx]
206
-
207
- if self.upcast_attention:
208
- query_slice = query_slice.float()
209
- key_slice = key_slice.float()
210
-
211
- attn_slice = torch.baddbmm(
212
- torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
213
- query_slice,
214
- key_slice.transpose(-1, -2),
215
- beta=0,
216
- alpha=self.scale,
217
- )
218
-
219
- if attention_mask is not None:
220
- attn_slice = attn_slice + attention_mask[start_idx:end_idx]
221
-
222
- if self.upcast_softmax:
223
- attn_slice = attn_slice.float()
224
-
225
- attn_slice = attn_slice.softmax(dim=-1)
226
-
227
- # cast back to the original dtype
228
- attn_slice = attn_slice.to(value.dtype)
229
- attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
230
-
231
- hidden_states[start_idx:end_idx] = attn_slice
232
-
233
- # reshape hidden_states
234
- hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
235
- return hidden_states
236
-
237
- def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
238
- # TODO attention_mask
239
- query = query.contiguous()
240
- key = key.contiguous()
241
- value = value.contiguous()
242
- hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
243
- hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
244
- return hidden_states
245
-
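The block removed above is the hand-rolled CrossAttention machinery (manual head reshaping, baddbmm score computation, the sliced-attention fallback and the xformers path); the rewritten module further down delegates all of this to diffusers' Attention class. For reference, the same scaled-dot-product computation in a few lines with PyTorch's fused kernel (a standalone sketch, not part of this commit; no mask or upcast handling):

import torch
import torch.nn.functional as F

def simple_attention(query, key, value, heads):
    # Rough equivalent of the deleted _attention path: split heads, fused SDPA, merge heads.
    def split(x):
        b, n, _ = x.shape
        return x.view(b, n, heads, -1).transpose(1, 2)           # (b, heads, n, dim_head)
    out = F.scaled_dot_product_attention(split(query), split(key), split(value))
    b, h, n, d = out.shape
    return out.transpose(1, 2).reshape(b, n, h * d)              # back to (b, n, heads * dim_head)

q = k = v = torch.randn(2, 77, 8 * 64)
print(simple_attention(q, k, v, heads=8).shape)                  # torch.Size([2, 77, 512])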
246
  def zero_module(module):
247
  # Zero out the parameters of a module and return it.
248
  for p in module.parameters():
@@ -275,6 +60,11 @@ class VanillaTemporalModule(nn.Module):
275
  zero_initialize = True,
276
  block_size = 1,
277
  grid = False,
278
  ):
279
  super().__init__()
280
 
@@ -289,17 +79,87 @@ class VanillaTemporalModule(nn.Module):
289
  temporal_position_encoding_max_len=temporal_position_encoding_max_len,
290
  grid=grid,
291
  block_size=block_size,
292
  )
293
  if zero_initialize:
294
  self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
295
 
296
  def forward(self, input_tensor, encoder_hidden_states=None, attention_mask=None, anchor_frame_idx=None):
297
  hidden_states = input_tensor
298
  hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
299
 
300
  output = hidden_states
301
  return output
302
 
303
  class TemporalTransformer3DModel(nn.Module):
304
  def __init__(
305
  self,
@@ -321,6 +181,8 @@ class TemporalTransformer3DModel(nn.Module):
321
  temporal_position_encoding_max_len = 4096,
322
  grid = False,
323
  block_size = 1,
324
  ):
325
  super().__init__()
326
 
@@ -348,6 +210,8 @@ class TemporalTransformer3DModel(nn.Module):
348
  temporal_position_encoding_max_len=temporal_position_encoding_max_len,
349
  block_size=block_size,
350
  grid=grid,
351
  )
352
  for d in range(num_layers)
353
  ]
@@ -398,6 +262,8 @@ class TemporalTransformerBlock(nn.Module):
398
  temporal_position_encoding_max_len = 4096,
399
  block_size = 1,
400
  grid = False,
401
  ):
402
  super().__init__()
403
 
@@ -422,15 +288,36 @@ class TemporalTransformerBlock(nn.Module):
422
  temporal_position_encoding_max_len=temporal_position_encoding_max_len,
423
  block_size=block_size,
424
  grid=grid,
425
  )
426
  )
427
- norms.append(nn.LayerNorm(dim))
428
 
429
  self.attention_blocks = nn.ModuleList(attention_blocks)
430
  self.norms = nn.ModuleList(norms)
431
 
432
  self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
433
- self.ff_norm = nn.LayerNorm(dim)
434
 
435
  def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, height=None, weight=None):
436
  for attention_block, norm in zip(self.attention_blocks, self.norms):
@@ -468,7 +355,7 @@ class PositionalEncoding(nn.Module):
468
  x = x + self.pe[:, :x.size(1)]
469
  return self.dropout(x)
470
 
471
- class VersatileAttention(CrossAttention):
472
  def __init__(
473
  self,
474
  attention_mode = None,
@@ -477,21 +364,23 @@ class VersatileAttention(CrossAttention):
477
  temporal_position_encoding_max_len = 4096,
478
  grid = False,
479
  block_size = 1,
 
480
  *args, **kwargs
481
  ):
482
  super().__init__(*args, **kwargs)
483
- assert attention_mode == "Temporal"
484
 
485
  self.attention_mode = attention_mode
486
  self.is_cross_attention = kwargs["cross_attention_dim"] is not None
487
 
488
  self.block_size = block_size
489
  self.grid = grid
 
490
  self.pos_encoder = PositionalEncoding(
491
  kwargs["query_dim"],
492
  dropout=0.,
493
  max_len=temporal_position_encoding_max_len
494
- ) if (temporal_position_encoding and attention_mode == "Temporal") else None
495
 
496
  def extra_repr(self):
497
  return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
@@ -503,8 +392,13 @@ class VersatileAttention(CrossAttention):
503
  # for add pos_encoder
504
  _, before_d, _c = hidden_states.size()
505
  hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
506
- if self.pos_encoder is not None:
507
- hidden_states = self.pos_encoder(hidden_states)
508
 
509
  if self.grid:
510
  hidden_states = rearrange(hidden_states, "(b d) f c -> b f d c", f=video_length, d=before_d)
@@ -515,61 +409,36 @@ class VersatileAttention(CrossAttention):
515
  else:
516
  d = before_d
517
  encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
518
  else:
519
  raise NotImplementedError
520
 
521
- encoder_hidden_states = encoder_hidden_states
522
-
523
- if self.group_norm is not None:
524
- hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
525
-
526
- query = self.to_q(hidden_states)
527
- dim = query.shape[-1]
528
- query = self.reshape_heads_to_batch_dim(query)
529
-
530
- if self.added_kv_proj_dim is not None:
531
- raise NotImplementedError
532
-
533
  encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
534
- key = self.to_k(encoder_hidden_states)
535
- value = self.to_v(encoder_hidden_states)
536
-
537
- key = self.reshape_heads_to_batch_dim(key)
538
- value = self.reshape_heads_to_batch_dim(value)
539
-
540
- if attention_mask is not None:
541
- if attention_mask.shape[-1] != query.shape[1]:
542
- target_length = query.shape[1]
543
- attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
544
- attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
545
 
546
  bs = 512
547
  new_hidden_states = []
548
- for i in range(0, query.shape[0], bs):
549
- # attention, what we cannot get enough of
550
- if self._use_memory_efficient_attention_xformers:
551
- hidden_states = self._memory_efficient_attention_xformers(query[i : i + bs], key[i : i + bs], value[i : i + bs], attention_mask[i : i + bs] if attention_mask is not None else attention_mask)
552
- # Some versions of xformers return output in fp32, cast it back to the dtype of the input
553
- hidden_states = hidden_states.to(query.dtype)
554
- else:
555
- if self._slice_size is None or query[i : i + bs].shape[0] // self._slice_size == 1:
556
- hidden_states = self._attention(query[i : i + bs], key[i : i + bs], value[i : i + bs], attention_mask[i : i + bs] if attention_mask is not None else attention_mask)
557
- else:
558
- hidden_states = self._sliced_attention(query[i : i + bs], key[i : i + bs], value[i : i + bs], sequence_length, dim, attention_mask[i : i + bs] if attention_mask is not None else attention_mask)
559
- new_hidden_states.append(hidden_states)
560
  hidden_states = torch.cat(new_hidden_states, dim = 0)
561
 
562
- # linear proj
563
- hidden_states = self.to_out[0](hidden_states)
564
-
565
- # dropout
566
- hidden_states = self.to_out[1](hidden_states)
567
-
568
  if self.attention_mode == "Temporal":
569
  hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
570
  if self.grid:
571
  hidden_states = rearrange(hidden_states, "(b f n m) (h w) c -> (b f) h n w m c", f=video_length, n=self.block_size, m=self.block_size, h=height // self.block_size, w=weight // self.block_size)
572
  hidden_states = rearrange(hidden_states, "b h n w m c -> b (h n) (w m) c")
573
  hidden_states = rearrange(hidden_states, "b h w c -> b (h w) c")
574
 
575
  return hidden_states
 
1
  """Modified from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
2
  """
3
  import math
 
4
 
5
+ import diffusers
6
+ import pkg_resources
7
  import torch
8
+
9
+ installed_version = diffusers.__version__
10
+
11
+ if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
12
+ from diffusers.models.attention_processor import (Attention,
13
+ AttnProcessor2_0,
14
+ HunyuanAttnProcessor2_0)
15
+ else:
16
+ from diffusers.models.attention_processor import Attention, AttnProcessor2_0
17
+
18
  from diffusers.models.attention import FeedForward
19
  from diffusers.utils.import_utils import is_xformers_available
20
  from einops import rearrange, repeat
21
  from torch import nn
22
 
23
+ from .norm import FP32LayerNorm
24
+
25
  if is_xformers_available():
26
  import xformers
27
  import xformers.ops
28
  else:
29
  xformers = None
30
 
 
31
  def zero_module(module):
32
  # Zero out the parameters of a module and return it.
33
  for p in module.parameters():
 
60
  zero_initialize = True,
61
  block_size = 1,
62
  grid = False,
63
+ remove_time_embedding_in_photo = False,
64
+
65
+ global_num_attention_heads = 16,
66
+ global_attention = False,
67
+ qk_norm = False,
68
  ):
69
  super().__init__()
70
 
 
79
  temporal_position_encoding_max_len=temporal_position_encoding_max_len,
80
  grid=grid,
81
  block_size=block_size,
82
+ remove_time_embedding_in_photo=remove_time_embedding_in_photo,
83
+ qk_norm=qk_norm,
84
  )
85
+ self.global_transformer = GlobalTransformer3DModel(
86
+ in_channels=in_channels,
87
+ num_attention_heads=global_num_attention_heads,
88
+ attention_head_dim=in_channels // global_num_attention_heads // temporal_attention_dim_div,
89
+ qk_norm=qk_norm,
90
+ ) if global_attention else None
91
  if zero_initialize:
92
  self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
93
+ if global_attention:
94
+ self.global_transformer.proj_out = zero_module(self.global_transformer.proj_out)
95
 
96
  def forward(self, input_tensor, encoder_hidden_states=None, attention_mask=None, anchor_frame_idx=None):
97
  hidden_states = input_tensor
98
  hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
99
+ if self.global_transformer is not None:
100
+ hidden_states = self.global_transformer(hidden_states)
101
 
102
  output = hidden_states
103
  return output
104
 
105
+ class GlobalTransformer3DModel(nn.Module):
106
+ def __init__(
107
+ self,
108
+ in_channels,
109
+ num_attention_heads,
110
+ attention_head_dim,
111
+ dropout = 0.0,
112
+ attention_bias = False,
113
+ upcast_attention = False,
114
+ qk_norm = False,
115
+ ):
116
+ super().__init__()
117
+
118
+ inner_dim = num_attention_heads * attention_head_dim
119
+
120
+ self.norm1 = FP32LayerNorm(inner_dim)
121
+ self.proj_in = nn.Linear(in_channels, inner_dim)
122
+ self.norm2 = FP32LayerNorm(inner_dim)
123
+ if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
124
+ self.attention = Attention(
125
+ query_dim=inner_dim,
126
+ heads=num_attention_heads,
127
+ dim_head=attention_head_dim,
128
+ dropout=dropout,
129
+ bias=attention_bias,
130
+ upcast_attention=upcast_attention,
131
+ qk_norm="layer_norm" if qk_norm else None,
132
+ processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
133
+ )
134
+ else:
135
+ self.attention = Attention(
136
+ query_dim=inner_dim,
137
+ heads=num_attention_heads,
138
+ dim_head=attention_head_dim,
139
+ dropout=dropout,
140
+ bias=attention_bias,
141
+ upcast_attention=upcast_attention,
142
+ )
143
+ self.proj_out = nn.Linear(inner_dim, in_channels)
144
+
145
+ def forward(self, hidden_states):
146
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
147
+ video_length, height, width = hidden_states.shape[2], hidden_states.shape[3], hidden_states.shape[4]
148
+ hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
149
+
150
+ residual = hidden_states
151
+ hidden_states = self.norm1(hidden_states)
152
+ hidden_states = self.proj_in(hidden_states)
153
+
154
+ # Attention Blocks
155
+ hidden_states = self.norm2(hidden_states)
156
+ hidden_states = self.attention(hidden_states)
157
+ hidden_states = self.proj_out(hidden_states)
158
+
159
+ output = hidden_states + residual
160
+ output = rearrange(output, "b (f h w) c -> b c f h w", f=video_length, h=height, w=width)
161
+ return output
162
+
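GlobalTransformer3DModel treats every (frame, row, column) latent position as one token: the 5-D input is flattened with einops, normalized and projected in, normalized again, passed through a single Attention block, projected back out, and added to the residual, with the zero-initialized proj_out keeping the whole branch an identity at the start of training. A small shape-only sketch of the flatten/restore round trip (concrete sizes are illustrative):

import torch
from einops import rearrange

x = torch.randn(2, 320, 8, 16, 16)                        # (batch, channels, frames, height, width)
tokens = rearrange(x, "b c f h w -> b (f h w) c")          # one token per spatio-temporal position
print(tokens.shape)                                        # torch.Size([2, 2048, 320])
x_back = rearrange(tokens, "b (f h w) c -> b c f h w", f=8, h=16, w=16)
assert torch.equal(x, x_back)                              # the rearrange is lossless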
163
  class TemporalTransformer3DModel(nn.Module):
164
  def __init__(
165
  self,
 
181
  temporal_position_encoding_max_len = 4096,
182
  grid = False,
183
  block_size = 1,
184
+ remove_time_embedding_in_photo = False,
185
+ qk_norm = False,
186
  ):
187
  super().__init__()
188
 
 
210
  temporal_position_encoding_max_len=temporal_position_encoding_max_len,
211
  block_size=block_size,
212
  grid=grid,
213
+ remove_time_embedding_in_photo=remove_time_embedding_in_photo,
214
+ qk_norm=qk_norm
215
  )
216
  for d in range(num_layers)
217
  ]
 
262
  temporal_position_encoding_max_len = 4096,
263
  block_size = 1,
264
  grid = False,
265
+ remove_time_embedding_in_photo = False,
266
+ qk_norm = False,
267
  ):
268
  super().__init__()
269
 
 
288
  temporal_position_encoding_max_len=temporal_position_encoding_max_len,
289
  block_size=block_size,
290
  grid=grid,
291
+ remove_time_embedding_in_photo=remove_time_embedding_in_photo,
292
+ qk_norm="layer_norm" if qk_norm else None,
293
+ processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
294
+ ) if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2") else \
295
+ VersatileAttention(
296
+ attention_mode=block_name.split("_")[0],
297
+ cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
298
+
299
+ query_dim=dim,
300
+ heads=num_attention_heads,
301
+ dim_head=attention_head_dim,
302
+ dropout=dropout,
303
+ bias=attention_bias,
304
+ upcast_attention=upcast_attention,
305
+
306
+ cross_frame_attention_mode=cross_frame_attention_mode,
307
+ temporal_position_encoding=temporal_position_encoding,
308
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
309
+ block_size=block_size,
310
+ grid=grid,
311
+ remove_time_embedding_in_photo=remove_time_embedding_in_photo,
312
  )
313
  )
314
+ norms.append(FP32LayerNorm(dim))
315
 
316
  self.attention_blocks = nn.ModuleList(attention_blocks)
317
  self.norms = nn.ModuleList(norms)
318
 
319
  self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
320
+ self.ff_norm = FP32LayerNorm(dim)
321
 
322
  def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, height=None, weight=None):
323
  for attention_block, norm in zip(self.attention_blocks, self.norms):
 
355
  x = x + self.pe[:, :x.size(1)]
356
  return self.dropout(x)
357
 
358
+ class VersatileAttention(Attention):
359
  def __init__(
360
  self,
361
  attention_mode = None,
 
364
  temporal_position_encoding_max_len = 4096,
365
  grid = False,
366
  block_size = 1,
367
+ remove_time_embedding_in_photo = False,
368
  *args, **kwargs
369
  ):
370
  super().__init__(*args, **kwargs)
371
+ assert attention_mode == "Temporal" or attention_mode == "Global"
372
 
373
  self.attention_mode = attention_mode
374
  self.is_cross_attention = kwargs["cross_attention_dim"] is not None
375
 
376
  self.block_size = block_size
377
  self.grid = grid
378
+ self.remove_time_embedding_in_photo = remove_time_embedding_in_photo
379
  self.pos_encoder = PositionalEncoding(
380
  kwargs["query_dim"],
381
  dropout=0.,
382
  max_len=temporal_position_encoding_max_len
383
+ ) if (temporal_position_encoding and attention_mode == "Temporal") or (temporal_position_encoding and attention_mode == "Global") else None
384
 
385
  def extra_repr(self):
386
  return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
 
392
  # for add pos_encoder
393
  _, before_d, _c = hidden_states.size()
394
  hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
395
+
396
+ if self.remove_time_embedding_in_photo:
397
+ if self.pos_encoder is not None and video_length > 1:
398
+ hidden_states = self.pos_encoder(hidden_states)
399
+ else:
400
+ if self.pos_encoder is not None:
401
+ hidden_states = self.pos_encoder(hidden_states)
402
 
403
  if self.grid:
404
  hidden_states = rearrange(hidden_states, "(b d) f c -> b f d c", f=video_length, d=before_d)
 
409
  else:
410
  d = before_d
411
  encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
412
+ elif self.attention_mode == "Global":
413
+ # for add pos_encoder
414
+ _, d, _c = hidden_states.size()
415
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
416
+ if self.pos_encoder is not None:
417
+ hidden_states = self.pos_encoder(hidden_states)
418
+ hidden_states = rearrange(hidden_states, "(b d) f c -> b (f d) c", f=video_length, d=d)
419
  else:
420
  raise NotImplementedError
421
 
 
422
  encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
423
 
424
  bs = 512
425
  new_hidden_states = []
426
+ for i in range(0, hidden_states.shape[0], bs):
427
+ __hidden_states = super().forward(
428
+ hidden_states[i : i + bs],
429
+ encoder_hidden_states=encoder_hidden_states[i : i + bs],
430
+ attention_mask=attention_mask
431
+ )
432
+ new_hidden_states.append(__hidden_states)
433
  hidden_states = torch.cat(new_hidden_states, dim = 0)
434
 
435
  if self.attention_mode == "Temporal":
436
  hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
437
  if self.grid:
438
  hidden_states = rearrange(hidden_states, "(b f n m) (h w) c -> (b f) h n w m c", f=video_length, n=self.block_size, m=self.block_size, h=height // self.block_size, w=weight // self.block_size)
439
  hidden_states = rearrange(hidden_states, "b h n w m c -> b (h n) (w m) c")
440
  hidden_states = rearrange(hidden_states, "b h w c -> b (h w) c")
441
+ elif self.attention_mode == "Global":
442
+ hidden_states = rearrange(hidden_states, "b (f d) c -> (b f) d c", f=video_length, d=d)
443
 
444
  return hidden_states
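VersatileAttention now routes the actual attention through the parent Attention.forward, but slices the flattened (batch * spatial positions) axis into chunks of 512 so that temporal or global attention over every spatial location never has to be attended in one huge batch. The same chunking idiom in isolation (the softmax stand-in is only a placeholder for the attention call):

import torch

def chunked_apply(fn, x, chunk=512):
    # Process the batch axis in slices and re-concatenate; peak memory stays bounded.
    outs = [fn(x[i:i + chunk]) for i in range(0, x.shape[0], chunk)]
    return torch.cat(outs, dim=0)

x = torch.randn(4096, 16, 64)                              # e.g. (batch * positions, frames, channels)
y = chunked_apply(lambda t: torch.softmax(t, dim=-1), x)   # placeholder for super().forward(...)
print(y.shape)                                             # torch.Size([4096, 16, 64])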
easyanimate/models/norm.py ADDED
@@ -0,0 +1,97 @@
1
+ from typing import Any, Dict, Optional, Tuple
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
6
+ from torch import nn
7
+
8
+
9
+ def zero_module(module):
10
+ # Zero out the parameters of a module and return it.
11
+ for p in module.parameters():
12
+ p.detach().zero_()
13
+ return module
14
+
15
+
16
+ class FP32LayerNorm(nn.LayerNorm):
17
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
18
+ origin_dtype = inputs.dtype
19
+ if hasattr(self, 'weight') and self.weight is not None:
20
+ return F.layer_norm(
21
+ inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps
22
+ ).to(origin_dtype)
23
+ else:
24
+ return F.layer_norm(
25
+ inputs.float(), self.normalized_shape, None, None, self.eps
26
+ ).to(origin_dtype)
27
+
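FP32LayerNorm upcasts only the normalization itself to float32 and casts the result back to the input dtype, which sidesteps the overflow and precision problems LayerNorm can hit under fp16/bf16 training while keeping the module's parameters in the low-precision format. A short usage sketch (the feature size is illustrative):

import torch
from easyanimate.models.norm import FP32LayerNorm

norm = FP32LayerNorm(320).half()                  # weights stored in fp16
x = torch.randn(2, 1024, 320, dtype=torch.float16)
y = norm(x)
print(y.dtype)                                    # torch.float16, but the layer_norm itself ran in fp32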
28
+ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
29
+ """
30
+ For PixArt-Alpha.
31
+
32
+ Reference:
33
+ https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
34
+ """
35
+
36
+ def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
37
+ super().__init__()
38
+
39
+ self.outdim = size_emb_dim
40
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
41
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
42
+
43
+ self.use_additional_conditions = use_additional_conditions
44
+ if use_additional_conditions:
45
+ self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
46
+ self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
47
+ self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
48
+
49
+ self.resolution_embedder.linear_2 = zero_module(self.resolution_embedder.linear_2)
50
+ self.aspect_ratio_embedder.linear_2 = zero_module(self.aspect_ratio_embedder.linear_2)
51
+
52
+ def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
53
+ timesteps_proj = self.time_proj(timestep)
54
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
55
+
56
+ if self.use_additional_conditions:
57
+ resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
58
+ resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
59
+ aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
60
+ aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
61
+ conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
62
+ else:
63
+ conditioning = timesteps_emb
64
+
65
+ return conditioning
66
+
67
+ class AdaLayerNormSingle(nn.Module):
68
+ r"""
69
+ Norm layer adaptive layer norm single (adaLN-single).
70
+
71
+ As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
72
+
73
+ Parameters:
74
+ embedding_dim (`int`): The size of each embedding vector.
75
+ use_additional_conditions (`bool`): To use additional conditions for normalization or not.
76
+ """
77
+
78
+ def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
79
+ super().__init__()
80
+
81
+ self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
82
+ embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
83
+ )
84
+
85
+ self.silu = nn.SiLU()
86
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
87
+
88
+ def forward(
89
+ self,
90
+ timestep: torch.Tensor,
91
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
92
+ batch_size: Optional[int] = None,
93
+ hidden_dtype: Optional[torch.dtype] = None,
94
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
95
+ # No modulation happening here.
96
+ embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
97
+ return self.linear(self.silu(embedded_timestep)), embedded_timestep
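AdaLayerNormSingle follows the PixArt-Alpha adaLN-single recipe: a single timestep embedding is mapped to a 6 * embedding_dim vector that downstream transformer blocks chunk into shift/scale/gate triples for their attention and feed-forward branches. A hedged sketch of how such an output is typically consumed (the tensors below are stand-ins, not this module's real outputs):

import torch

batch, dim = 2, 1152
modulation = torch.randn(batch, 6 * dim)           # stand-in for AdaLayerNormSingle's first return value
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=1)

tokens = torch.randn(batch, 4096, dim)             # hypothetical transformer hidden states
modulated = tokens * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
print(modulated.shape)                             # torch.Size([2, 4096, 1152])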
easyanimate/models/patch.py CHANGED
@@ -1,10 +1,10 @@
 
1
  from typing import Optional
2
 
3
  import numpy as np
4
  import torch
5
  import torch.nn.functional as F
6
  import torch.nn.init as init
7
- import math
8
  from einops import rearrange
9
  from torch import nn
10
 
 
1
+ import math
2
  from typing import Optional
3
 
4
  import numpy as np
5
  import torch
6
  import torch.nn.functional as F
7
  import torch.nn.init as init
 
8
  from einops import rearrange
9
  from torch import nn
10
 
easyanimate/models/transformer3d.py CHANGED
@@ -15,26 +15,30 @@ import json
15
  import math
16
  import os
17
  from dataclasses import dataclass
18
- from typing import Any, Dict, Optional
19
 
20
  import numpy as np
21
  import torch
22
  import torch.nn.functional as F
23
  import torch.nn.init as init
24
  from diffusers.configuration_utils import ConfigMixin, register_to_config
25
- from diffusers.models.attention import BasicTransformerBlock
26
- from diffusers.models.embeddings import PatchEmbed, Timesteps, TimestepEmbedding
 
27
  from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
28
  from diffusers.models.modeling_utils import ModelMixin
29
- from diffusers.models.normalization import AdaLayerNormSingle
30
- from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, is_torch_version
31
  from einops import rearrange
32
  from torch import nn
33
- from typing import Dict, Optional, Tuple
34
 
35
  from .attention import (SelfAttentionTemporalTransformerBlock,
36
  TemporalTransformerBlock)
37
- from .patch import Patch1D, PatchEmbed3D, PatchEmbedF3D, UnPatch1D, TemporalUpsampler3D, CasualPatchEmbed3D
38
 
39
  try:
40
  from diffusers.models.embeddings import PixArtAlphaTextProjection
@@ -48,77 +52,25 @@ def zero_module(module):
48
  p.detach().zero_()
49
  return module
50
 
51
- class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
52
- """
53
- For PixArt-Alpha.
54
 
55
- Reference:
56
- https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
57
  """
 
58
 
59
- def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
60
- super().__init__()
61
-
62
- self.outdim = size_emb_dim
63
- self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
64
- self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
65
-
66
- self.use_additional_conditions = use_additional_conditions
67
- if use_additional_conditions:
68
- self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
69
- self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
70
- self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
71
-
72
- self.resolution_embedder.linear_2 = zero_module(self.resolution_embedder.linear_2)
73
- self.aspect_ratio_embedder.linear_2 = zero_module(self.aspect_ratio_embedder.linear_2)
74
-
75
- def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
76
- timesteps_proj = self.time_proj(timestep)
77
- timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
78
-
79
- if self.use_additional_conditions:
80
- resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
81
- resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
82
- aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
83
- aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
84
- conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
85
- else:
86
- conditioning = timesteps_emb
87
-
88
- return conditioning
89
-
90
- class AdaLayerNormSingle(nn.Module):
91
- r"""
92
- Norm layer adaptive layer norm single (adaLN-single).
93
-
94
- As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
95
-
96
- Parameters:
97
- embedding_dim (`int`): The size of each embedding vector.
98
- use_additional_conditions (`bool`): To use additional conditions for normalization or not.
99
  """
100
 
101
- def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
102
  super().__init__()
103
-
104
- self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
105
- embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
106
- )
107
-
108
- self.silu = nn.SiLU()
109
- self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
110
-
111
- def forward(
112
- self,
113
- timestep: torch.Tensor,
114
- added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
115
- batch_size: Optional[int] = None,
116
- hidden_dtype: Optional[torch.dtype] = None,
117
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
118
- # No modulation happening here.
119
- embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
120
- return self.linear(self.silu(embedded_timestep)), embedded_timestep
121
-
122
 
123
  class TimePositionalEncoding(nn.Module):
124
  def __init__(
@@ -229,9 +181,14 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
229
  # motion module kwargs
230
  motion_module_type = "VanillaGrid",
231
  motion_module_kwargs = None,
 
 
233
  # time position encoding
234
- time_position_encoding_before_transformer = False
235
  ):
236
  super().__init__()
237
  self.use_linear_projection = use_linear_projection
@@ -320,6 +277,35 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
320
  attention_type=attention_type,
321
  motion_module_type=motion_module_type,
322
  motion_module_kwargs=motion_module_kwargs,
323
  )
324
  for d in range(num_layers)
325
  ]
@@ -346,6 +332,8 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
346
  kvcompression=False if d < 14 else True,
347
  motion_module_type=motion_module_type,
348
  motion_module_kwargs=motion_module_kwargs,
349
  )
350
  for d in range(num_layers)
351
  ]
@@ -369,6 +357,8 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
369
  norm_elementwise_affine=norm_elementwise_affine,
370
  norm_eps=norm_eps,
371
  attention_type=attention_type,
372
  )
373
  for d in range(num_layers)
374
  ]
@@ -438,8 +428,11 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
438
  self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
439
 
440
  self.caption_projection = None
 
441
  if caption_channels is not None:
442
  self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
443
 
444
  self.gradient_checkpointing = False
445
 
@@ -456,12 +449,14 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
456
  hidden_states: torch.Tensor,
457
  inpaint_latents: torch.Tensor = None,
458
  encoder_hidden_states: Optional[torch.Tensor] = None,
 
459
  timestep: Optional[torch.LongTensor] = None,
460
  added_cond_kwargs: Dict[str, torch.Tensor] = None,
461
  class_labels: Optional[torch.LongTensor] = None,
462
  cross_attention_kwargs: Dict[str, Any] = None,
463
  attention_mask: Optional[torch.Tensor] = None,
464
  encoder_attention_mask: Optional[torch.Tensor] = None,
 
465
  return_dict: bool = True,
466
  ):
467
  """
@@ -520,6 +515,8 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
520
  attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
521
  attention_mask = attention_mask.unsqueeze(1)
522
 
 
 
523
  # convert encoder_attention_mask to a bias the same way we do for attention_mask
524
  if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
525
  encoder_attention_mask = (1 - encoder_attention_mask.to(encoder_hidden_states.dtype)) * -10000.0
@@ -560,6 +557,13 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
560
  encoder_hidden_states = self.caption_projection(encoder_hidden_states)
561
  encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
562
 
 
563
  skips = []
564
  skip_index = 0
565
  for index, block in enumerate(self.transformer_blocks):
@@ -590,7 +594,8 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
590
  args = {
591
  "basic": [],
592
  "motionmodule": [video_length, height, width],
593
- "selfattentiontemporal": [video_length, height, width],
 
594
  "kvcompression_motionmodule": [video_length, height, width],
595
  }[self.basic_block_type]
596
  hidden_states = torch.utils.checkpoint.checkpoint(
@@ -609,7 +614,8 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
609
  kwargs = {
610
  "basic": {},
611
  "motionmodule": {"num_frames":video_length, "height":height, "width":width},
612
- "selfattentiontemporal": {"num_frames":video_length, "height":height, "width":width},
 
613
  "kvcompression_motionmodule": {"num_frames":video_length, "height":height, "width":width},
614
  }[self.basic_block_type]
615
  hidden_states = block(
 
15
  import math
16
  import os
17
  from dataclasses import dataclass
18
+ from typing import Any, Dict, Optional, Tuple
19
 
20
  import numpy as np
21
  import torch
22
  import torch.nn.functional as F
23
  import torch.nn.init as init
24
  from diffusers.configuration_utils import ConfigMixin, register_to_config
25
+ from diffusers.models.attention import BasicTransformerBlock, FeedForward
26
+ from diffusers.models.embeddings import (PatchEmbed, PixArtAlphaTextProjection,
27
+ TimestepEmbedding, Timesteps)
28
  from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
29
  from diffusers.models.modeling_utils import ModelMixin
30
+ from diffusers.models.normalization import AdaLayerNormContinuous
31
+ from diffusers.utils import (USE_PEFT_BACKEND, BaseOutput, is_torch_version,
32
+ logging)
33
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
34
  from einops import rearrange
35
  from torch import nn
 
36
 
37
  from .attention import (SelfAttentionTemporalTransformerBlock,
38
  TemporalTransformerBlock)
39
+ from .norm import AdaLayerNormSingle
40
+ from .patch import (CasualPatchEmbed3D, Patch1D, PatchEmbed3D, PatchEmbedF3D,
41
+ TemporalUpsampler3D, UnPatch1D)
42
 
43
  try:
44
  from diffusers.models.embeddings import PixArtAlphaTextProjection
 
52
  p.detach().zero_()
53
  return module
54
 
55
 
56
+ class CLIPProjection(nn.Module):
 
57
  """
58
+ Projects caption embeddings. Also handles dropout for classifier-free guidance.
59
 
60
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
61
  """
62
 
63
+ def __init__(self, in_features, hidden_size, num_tokens=120):
64
  super().__init__()
65
+ self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
66
+ self.act_1 = nn.GELU(approximate="tanh")
67
+ self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
68
+ self.linear_2 = zero_module(self.linear_2)
69
+ def forward(self, caption):
70
+ hidden_states = self.linear_1(caption)
71
+ hidden_states = self.act_1(hidden_states)
72
+ hidden_states = self.linear_2(hidden_states)
73
+ return hidden_states
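CLIPProjection maps CLIP image embeddings into the transformer's hidden size, and because its second linear layer is zero-initialized the projected image tokens start out as exact zeros: the image condition only begins to steer generation once training has updated that layer. The same zero-init property in a minimal sketch (layer sizes are illustrative):

import torch
import torch.nn as nn

proj = nn.Sequential(nn.Linear(768, 1152), nn.GELU(approximate="tanh"), nn.Linear(1152, 1152))
nn.init.zeros_(proj[-1].weight)
nn.init.zeros_(proj[-1].bias)

clip_tokens = torch.randn(2, 257, 768)              # e.g. CLIP ViT patch embeddings
print(proj(clip_tokens).abs().max())                # tensor(0.) -- no contribution at step 0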
74
 
75
  class TimePositionalEncoding(nn.Module):
76
  def __init__(
 
181
  # motion module kwargs
182
  motion_module_type = "VanillaGrid",
183
  motion_module_kwargs = None,
184
+ motion_module_kwargs_odd = None,
185
+ motion_module_kwargs_even = None,
186
 
187
  # time position encoding
188
+ time_position_encoding_before_transformer = False,
189
+
190
+ qk_norm = False,
191
+ after_norm = False,
192
  ):
193
  super().__init__()
194
  self.use_linear_projection = use_linear_projection
 
277
  attention_type=attention_type,
278
  motion_module_type=motion_module_type,
279
  motion_module_kwargs=motion_module_kwargs,
280
+ qk_norm=qk_norm,
281
+ after_norm=after_norm,
282
+ )
283
+ for d in range(num_layers)
284
+ ]
285
+ )
286
+ elif self.basic_block_type == "global_motionmodule":
287
+ self.transformer_blocks = nn.ModuleList(
288
+ [
289
+ TemporalTransformerBlock(
290
+ inner_dim,
291
+ num_attention_heads,
292
+ attention_head_dim,
293
+ dropout=dropout,
294
+ cross_attention_dim=cross_attention_dim,
295
+ activation_fn=activation_fn,
296
+ num_embeds_ada_norm=num_embeds_ada_norm,
297
+ attention_bias=attention_bias,
298
+ only_cross_attention=only_cross_attention,
299
+ double_self_attention=double_self_attention,
300
+ upcast_attention=upcast_attention,
301
+ norm_type=norm_type,
302
+ norm_elementwise_affine=norm_elementwise_affine,
303
+ norm_eps=norm_eps,
304
+ attention_type=attention_type,
305
+ motion_module_type=motion_module_type,
306
+ motion_module_kwargs=motion_module_kwargs_even if d % 2 == 0 else motion_module_kwargs_odd,
307
+ qk_norm=qk_norm,
308
+ after_norm=after_norm,
309
  )
310
  for d in range(num_layers)
311
  ]
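The new global_motionmodule block type alternates two motion-module configurations across depth: even layers receive motion_module_kwargs_even and odd layers motion_module_kwargs_odd, which allows, for example, interleaving global and grid attention settings between layers. The selection idiom in isolation (the concrete keys are illustrative):

kwargs_even = {"global_attention": True}
kwargs_odd = {"global_attention": False, "grid": True}
per_layer = [kwargs_even if d % 2 == 0 else kwargs_odd for d in range(6)]
print(per_layer[:2])   # [{'global_attention': True}, {'global_attention': False, 'grid': True}]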
 
332
  kvcompression=False if d < 14 else True,
333
  motion_module_type=motion_module_type,
334
  motion_module_kwargs=motion_module_kwargs,
335
+ qk_norm=qk_norm,
336
+ after_norm=after_norm,
337
  )
338
  for d in range(num_layers)
339
  ]
 
357
  norm_elementwise_affine=norm_elementwise_affine,
358
  norm_eps=norm_eps,
359
  attention_type=attention_type,
360
+ qk_norm=qk_norm,
361
+ after_norm=after_norm,
362
  )
363
  for d in range(num_layers)
364
  ]
 
428
  self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
429
 
430
  self.caption_projection = None
431
+ self.clip_projection = None
432
  if caption_channels is not None:
433
  self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
434
+ if in_channels == 12:
435
+ self.clip_projection = CLIPProjection(in_features=768, hidden_size=inner_dim * 8)
436
 
437
  self.gradient_checkpointing = False
438
 
 
449
  hidden_states: torch.Tensor,
450
  inpaint_latents: torch.Tensor = None,
451
  encoder_hidden_states: Optional[torch.Tensor] = None,
452
+ clip_encoder_hidden_states: Optional[torch.Tensor] = None,
453
  timestep: Optional[torch.LongTensor] = None,
454
  added_cond_kwargs: Dict[str, torch.Tensor] = None,
455
  class_labels: Optional[torch.LongTensor] = None,
456
  cross_attention_kwargs: Dict[str, Any] = None,
457
  attention_mask: Optional[torch.Tensor] = None,
458
  encoder_attention_mask: Optional[torch.Tensor] = None,
459
+ clip_attention_mask: Optional[torch.Tensor] = None,
460
  return_dict: bool = True,
461
  ):
462
  """
 
515
  attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
516
  attention_mask = attention_mask.unsqueeze(1)
517
 
518
+ if clip_attention_mask is not None:
519
+ encoder_attention_mask = torch.cat([encoder_attention_mask, clip_attention_mask], dim=1)
520
  # convert encoder_attention_mask to a bias the same way we do for attention_mask
521
  if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
522
  encoder_attention_mask = (1 - encoder_attention_mask.to(encoder_hidden_states.dtype)) * -10000.0
 
557
  encoder_hidden_states = self.caption_projection(encoder_hidden_states)
558
  encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
559
 
560
+ if clip_encoder_hidden_states is not None and encoder_hidden_states is not None:
561
+ batch_size = hidden_states.shape[0]
562
+ clip_encoder_hidden_states = self.clip_projection(clip_encoder_hidden_states)
563
+ clip_encoder_hidden_states = clip_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
564
+
565
+ encoder_hidden_states = torch.cat([encoder_hidden_states, clip_encoder_hidden_states], dim = 1)
566
+
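When both T5 text embeddings and CLIP image embeddings are present, the projected CLIP tokens are appended to the text tokens along the sequence axis, and clip_attention_mask was concatenated onto encoder_attention_mask earlier in forward, so cross-attention sees a single, consistently masked conditioning sequence. A toy sketch of that pairing (shapes are illustrative):

import torch

text_tokens = torch.randn(2, 120, 1152)     # after caption_projection
image_tokens = torch.randn(2, 8, 1152)      # after clip_projection
text_mask = torch.ones(2, 120)
image_mask = torch.ones(2, 8)

cond = torch.cat([text_tokens, image_tokens], dim=1)    # (2, 128, 1152)
mask = torch.cat([text_mask, image_mask], dim=1)        # (2, 128) -- must stay aligned with cond
print(cond.shape, mask.shape)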
567
  skips = []
568
  skip_index = 0
569
  for index, block in enumerate(self.transformer_blocks):
 
594
  args = {
595
  "basic": [],
596
  "motionmodule": [video_length, height, width],
597
+ "global_motionmodule": [video_length, height, width],
598
+ "selfattentiontemporal": [],
599
  "kvcompression_motionmodule": [video_length, height, width],
600
  }[self.basic_block_type]
601
  hidden_states = torch.utils.checkpoint.checkpoint(
 
614
  kwargs = {
615
  "basic": {},
616
  "motionmodule": {"num_frames":video_length, "height":height, "width":width},
617
+ "global_motionmodule": {"num_frames":video_length, "height":height, "width":width},
618
+ "selfattentiontemporal": {},
619
  "kvcompression_motionmodule": {"num_frames":video_length, "height":height, "width":width},
620
  }[self.basic_block_type]
621
  hidden_states = block(
easyanimate/pipeline/pipeline_easyanimate.py CHANGED
@@ -578,7 +578,7 @@ class EasyAnimatePipeline(DiffusionPipeline):
578
 
579
  def decode_latents(self, latents):
580
  video_length = latents.shape[2]
581
- latents = 1 / 0.18215 * latents
582
  if self.vae.quant_conv.weight.ndim==5:
583
  mini_batch_encoder = self.vae.mini_batch_encoder
584
  mini_batch_decoder = self.vae.mini_batch_decoder
 
578
 
579
  def decode_latents(self, latents):
580
  video_length = latents.shape[2]
581
+ latents = 1 / self.vae.config.scaling_factor * latents
582
  if self.vae.quant_conv.weight.ndim==5:
583
  mini_batch_encoder = self.vae.mini_batch_encoder
584
  mini_batch_decoder = self.vae.mini_batch_decoder
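Replacing the hard-coded 1 / 0.18215 with self.vae.config.scaling_factor ties decode_latents to whatever latent scale the paired VAE was trained with; 0.18215 is only correct for the SD1.x-family VAE. A quick reference sketch of the symmetric encode/decode convention (the checkpoint id below is just an example):

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")    # example VAE; scaling_factor == 0.18215 here
pixels = torch.randn(1, 3, 64, 64)
latents = vae.encode(pixels).latent_dist.sample() * vae.config.scaling_factor
decoded = vae.decode(latents / vae.config.scaling_factor).sample    # undo the scale before decoding
print(decoded.shape)                                                # torch.Size([1, 3, 64, 64])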
easyanimate/pipeline/pipeline_easyanimate_inpaint.py CHANGED
@@ -15,13 +15,16 @@
15
  import html
16
  import inspect
17
  import re
 
18
  import copy
19
  import urllib.parse as ul
20
  from dataclasses import dataclass
 
21
  from typing import Callable, List, Optional, Tuple, Union
22
 
23
  import numpy as np
24
  import torch
 
25
  from diffusers import DiffusionPipeline, ImagePipelineOutput
26
  from diffusers.image_processor import VaeImageProcessor
27
  from diffusers.models import AutoencoderKL
@@ -33,6 +36,7 @@ from diffusers.utils.torch_utils import randn_tensor
33
  from einops import rearrange
34
  from tqdm import tqdm
35
  from transformers import T5EncoderModel, T5Tokenizer
 
36
 
37
  from ..models.transformer3d import Transformer3DModel
38
 
@@ -109,11 +113,15 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
109
  vae: AutoencoderKL,
110
  transformer: Transformer3DModel,
111
  scheduler: DPMSolverMultistepScheduler,
 
  ):
113
  super().__init__()
114
 
115
  self.register_modules(
116
- tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
117
  )
118
 
119
  self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
@@ -503,41 +511,64 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
503
  return_video_latents=False,
504
  ):
505
  if self.vae.quant_conv.weight.ndim==5:
506
- shape = (batch_size, num_channels_latents, int(video_length // 5 * 2) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
507
  else:
508
  shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
 
509
  if isinstance(generator, list) and len(generator) != batch_size:
510
  raise ValueError(
511
  f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
512
  f" size of {batch_size}. Make sure the batch size matches the length of the generators."
513
  )
514
-
515
  if return_video_latents or (latents is None and not is_strength_max):
516
- video = video.to(device=device, dtype=dtype)
517
 
518
- if video.shape[1] == 4:
519
- video_latents = video
520
  else:
521
- video_length = video.shape[2]
522
- video = rearrange(video, "b c f h w -> (b f) c h w")
523
- video_latents = self._encode_vae_image(image=video, generator=generator)
524
- video_latents = rearrange(video_latents, "(b f) c h w -> b c f h w", f=video_length)
525
- video_latents = video_latents.repeat(batch_size // video_latents.shape[0], 1, 1, 1, 1)
526
 
527
  if latents is None:
528
- rand_device = "cpu" if device.type == "mps" else device
529
-
530
- noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
531
  # if strength is 1. then initialise the latents to noise, else initial to image + noise
532
  latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
533
  else:
534
  noise = latents.to(device)
535
- if latents.shape != shape:
536
- raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
537
- latents = latents.to(device)
538
 
539
  # scale the initial noise by the standard deviation required by the scheduler
540
- latents = latents * self.scheduler.init_noise_sigma
541
  outputs = (latents,)
542
 
543
  if return_noise:
@@ -548,33 +579,61 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
548
 
549
  return outputs
550
 
551
- def decode_latents(self, latents):
552
- video_length = latents.shape[2]
553
- latents = 1 / 0.18215 * latents
554
- if self.vae.quant_conv.weight.ndim==5:
555
- mini_batch_decoder = 2
556
- # Decoder
557
- video = []
558
  for i in range(0, latents.shape[2], mini_batch_decoder):
559
  with torch.no_grad():
560
  start_index = i
561
  end_index = i + mini_batch_decoder
562
  latents_bs = self.vae.decode(latents[:, :, start_index:end_index, :, :])[0]
563
- video.append(latents_bs)
564
-
565
- # Smooth
566
- mini_batch_encoder = 5
567
- video = torch.cat(video, 2).cpu()
568
- for i in range(mini_batch_encoder, video.shape[2], mini_batch_encoder):
569
- origin_before = copy.deepcopy(video[:, :, i - 1, :, :])
570
- origin_after = copy.deepcopy(video[:, :, i, :, :])
571
-
572
- video[:, :, i - 1, :, :] = origin_before * 0.75 + origin_after * 0.25
573
- video[:, :, i, :, :] = origin_before * 0.25 + origin_after * 0.75
574
  video = video.clamp(-1, 1)
 
575
  else:
576
  latents = rearrange(latents, "b c f h w -> (b f) c h w")
577
- # video = self.vae.decode(latents).sample
578
  video = []
579
  for frame_idx in tqdm(range(latents.shape[0])):
580
  video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
@@ -599,6 +658,16 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
599
 
600
  return image_latents
601
 
602
  def prepare_mask_latents(
603
  self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
604
  ):
@@ -610,19 +679,26 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
610
  mask = mask.to(device=device, dtype=self.vae.dtype)
611
  if self.vae.quant_conv.weight.ndim==5:
612
  bs = 1
 
613
  new_mask = []
614
- for i in range(0, mask.shape[0], bs):
615
- mini_batch = 5
616
- new_mask_mini_batch = []
617
- for j in range(0, mask.shape[2], mini_batch):
618
- mask_bs = mask[i : i + bs, :, j: j + mini_batch, :, :]
619
  mask_bs = self.vae.encode(mask_bs)[0]
620
  mask_bs = mask_bs.sample()
621
- new_mask_mini_batch.append(mask_bs)
622
- new_mask_mini_batch = torch.cat(new_mask_mini_batch, dim = 2)
623
- new_mask.append(new_mask_mini_batch)
624
  mask = torch.cat(new_mask, dim = 0)
625
- mask = mask * 0.1825
626
 
627
  else:
628
  if mask.shape[1] == 4:
@@ -636,19 +712,26 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
636
  masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
637
  if self.vae.quant_conv.weight.ndim==5:
638
  bs = 1
 
639
  new_mask_pixel_values = []
640
- for i in range(0, masked_image.shape[0], bs):
641
- mini_batch = 5
642
- new_mask_pixel_values_mini_batch = []
643
- for j in range(0, masked_image.shape[2], mini_batch):
644
- mask_pixel_values_bs = masked_image[i : i + bs, :, j: j + mini_batch, :, :]
645
  mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
646
  mask_pixel_values_bs = mask_pixel_values_bs.sample()
647
- new_mask_pixel_values_mini_batch.append(mask_pixel_values_bs)
648
- new_mask_pixel_values_mini_batch = torch.cat(new_mask_pixel_values_mini_batch, dim = 2)
649
- new_mask_pixel_values.append(new_mask_pixel_values_mini_batch)
650
  masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
651
- masked_image_latents = masked_image_latents * 0.1825
652
 
653
  else:
654
  if masked_image.shape[1] == 4:
@@ -693,7 +776,9 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
693
  callback_steps: int = 1,
694
  clean_caption: bool = True,
695
  mask_feature: bool = True,
696
- max_sequence_length: int = 120
697
  ) -> Union[EasyAnimatePipelineOutput, Tuple]:
698
  """
699
  Function invoked when calling the pipeline for generation.
@@ -767,6 +852,8 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
767
  # 1. Check inputs. Raise error if not correct
768
  height = height or self.transformer.config.sample_size * self.vae_scale_factor
769
  width = width or self.transformer.config.sample_size * self.vae_scale_factor
770
 
771
  # 2. Default height and width to transformer
772
  if prompt is not None and isinstance(prompt, str):
@@ -806,11 +893,13 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
806
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
807
  prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
808
 
809
- # 4. Prepare timesteps
810
  self.scheduler.set_timesteps(num_inference_steps, device=device)
811
- timesteps = self.scheduler.timesteps
812
  # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
813
- latent_timestep = timesteps[:1].repeat(batch_size)
814
  # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
815
  is_strength_max = strength == 1.0
816
 
@@ -825,7 +914,7 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
825
  # Prepare latent variables
826
  num_channels_latents = self.vae.config.latent_channels
827
  num_channels_transformer = self.transformer.config.in_channels
828
- return_image_latents = num_channels_transformer == 4
829
 
830
  # 5. Prepare latents.
831
  latents_outputs = self.prepare_latents(
@@ -857,30 +946,83 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
857
  mask_condition = mask_condition.to(dtype=torch.float32)
858
  mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
859
 
860
- if masked_video_latents is None:
861
- masked_video = init_video * (mask_condition < 0.5) + torch.ones_like(init_video) * (mask_condition > 0.5) * -1
862
  else:
863
- masked_video = masked_video_latents
864
-
865
- mask, masked_video_latents = self.prepare_mask_latents(
866
- mask_condition,
867
- masked_video,
868
- batch_size,
869
- height,
870
- width,
871
- prompt_embeds.dtype,
872
- device,
873
- generator,
874
- do_classifier_free_guidance,
875
- )
876
  else:
877
- mask = torch.zeros_like(latents).to(latents.device, latents.dtype)
878
- masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
879
 
880
  # Check that sizes of mask, masked image and latents match
881
  if num_channels_transformer == 12:
882
  # default case for runwayml/stable-diffusion-inpainting
883
- num_channels_mask = mask.shape[1]
884
  num_channels_masked_image = masked_video_latents.shape[1]
885
  if num_channels_latents + num_channels_mask + num_channels_masked_image != self.transformer.config.in_channels:
886
  raise ValueError(
@@ -890,12 +1032,12 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
890
  f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
891
  " `pipeline.transformer` or your `mask_image` or `image` input."
892
  )
893
- elif num_channels_transformer == 4:
894
  raise ValueError(
895
  f"The transformer {self.transformer.__class__} should have 9 input channels, not {self.transformer.config.in_channels}."
896
  )
897
 
898
- # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
899
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
900
 
901
  # 6.1 Prepare micro-conditions.
@@ -912,21 +1054,25 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
912
 
913
  added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
914
 
915
- # 7. Denoising loop
916
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
 
917
 
918
  with self.progress_bar(total=num_inference_steps) as progress_bar:
919
  for i, t in enumerate(timesteps):
920
  latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
921
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
922
 
923
- if num_channels_transformer == 12:
924
- mask_input = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
925
- masked_video_latents_input = (
926
- torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
927
- )
928
- inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1)
929
-
930
  current_timestep = t
931
  if not torch.is_tensor(current_timestep):
932
  # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
@@ -949,7 +1095,9 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
949
  encoder_attention_mask=prompt_attention_mask,
950
  timestep=current_timestep,
951
  added_cond_kwargs=added_cond_kwargs,
952
- inpaint_latents=inpaint_latents.to(latent_model_input.dtype),
953
  return_dict=False,
954
  )[0]
955
 
@@ -964,6 +1112,17 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
964
  # compute previous image: x_t -> x_t-1
965
  latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
966
 
967
  # call the callback, if provided
968
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
969
  progress_bar.update()
@@ -971,9 +1130,16 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
971
  step_idx = i // getattr(self.scheduler, "order", 1)
972
  callback(step_idx, t, latents)
973
 
974
  # Post-processing
975
  video = self.decode_latents(latents)
976
-
977
  # Convert to tensor
978
  if output_type == "latent":
979
  video = torch.from_numpy(video)
 
15
  import html
16
  import inspect
17
  import re
18
+ import gc
19
  import copy
20
  import urllib.parse as ul
21
  from dataclasses import dataclass
22
+ from PIL import Image
23
  from typing import Callable, List, Optional, Tuple, Union
24
 
25
  import numpy as np
26
  import torch
27
+ import torch.nn.functional as F
28
  from diffusers import DiffusionPipeline, ImagePipelineOutput
29
  from diffusers.image_processor import VaeImageProcessor
30
  from diffusers.models import AutoencoderKL
 
36
  from einops import rearrange
37
  from tqdm import tqdm
38
  from transformers import T5EncoderModel, T5Tokenizer
39
+ from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
40
 
41
  from ..models.transformer3d import Transformer3DModel
42
 
 
113
  vae: AutoencoderKL,
114
  transformer: Transformer3DModel,
115
  scheduler: DPMSolverMultistepScheduler,
116
+ clip_image_processor:CLIPImageProcessor = None,
117
+ clip_image_encoder:CLIPVisionModelWithProjection = None,
118
  ):
119
  super().__init__()
120
 
121
  self.register_modules(
122
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer,
123
+ scheduler=scheduler,
124
+ clip_image_processor=clip_image_processor, clip_image_encoder=clip_image_encoder,
125
  )
126
 
127
  self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
 
511
  return_video_latents=False,
512
  ):
513
  if self.vae.quant_conv.weight.ndim==5:
514
+ mini_batch_encoder = self.vae.mini_batch_encoder
515
+ mini_batch_decoder = self.vae.mini_batch_decoder
516
+ shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
517
  else:
518
  shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
519
+
520
  if isinstance(generator, list) and len(generator) != batch_size:
521
  raise ValueError(
522
  f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
523
  f" size of {batch_size}. Make sure the batch size matches the length of the generators."
524
  )
525
+
526
  if return_video_latents or (latents is None and not is_strength_max):
527
+ video = video.to(device=device, dtype=self.vae.dtype)
528
+ if self.vae.quant_conv.weight.ndim==5:
529
+ bs = 1
530
+ mini_batch_encoder = self.vae.mini_batch_encoder
531
+ new_video = []
532
+ if self.vae.slice_compression_vae:
533
+ for i in range(0, video.shape[0], bs):
534
+ video_bs = video[i : i + bs]
535
+ video_bs = self.vae.encode(video_bs)[0]
536
+ video_bs = video_bs.sample()
537
+ new_video.append(video_bs)
538
+ else:
539
+ for i in range(0, video.shape[0], bs):
540
+ new_video_mini_batch = []
541
+ for j in range(0, video.shape[2], mini_batch_encoder):
542
+ video_bs = video[i : i + bs, :, j: j + mini_batch_encoder, :, :]
543
+ video_bs = self.vae.encode(video_bs)[0]
544
+ video_bs = video_bs.sample()
545
+ new_video_mini_batch.append(video_bs)
546
+ new_video_mini_batch = torch.cat(new_video_mini_batch, dim = 2)
547
+ new_video.append(new_video_mini_batch)
548
+ video = torch.cat(new_video, dim = 0)
549
+ video = video * self.vae.config.scaling_factor
550
 
 
 
551
  else:
552
+ if video.shape[1] == 4:
553
+ video = video
554
+ else:
555
+ video_length = video.shape[2]
556
+ video = rearrange(video, "b c f h w -> (b f) c h w")
557
+ video = self._encode_vae_image(video, generator=generator)
558
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
559
+ video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1)
560
 
561
  if latents is None:
562
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
 
 
563
  # if strength is 1. then initialise the latents to noise, else initial to image + noise
564
  latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
565
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
566
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
567
  else:
568
  noise = latents.to(device)
569
+ latents = noise * self.scheduler.init_noise_sigma
 
 
570
 
571
  # scale the initial noise by the standard deviation required by the scheduler
 
572
  outputs = (latents,)
573
 
574
  if return_noise:
 
579
 
580
  return outputs
581
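For reference, the temporal mini-batch encoding used in `prepare_latents` above can be reduced to the following sketch. It assumes a generic `encode_fn` in place of `self.vae.encode(...)[0].sample()`, with spatial average pooling standing in for the 3D VAE; shapes are illustrative only.

```python
import torch
import torch.nn.functional as F

def encode_in_temporal_chunks(video: torch.Tensor, encode_fn, mini_batch_encoder: int) -> torch.Tensor:
    # Encode the video a few frames at a time and re-assemble along the time axis,
    # mirroring the mini_batch_encoder loop in the hunk above.
    chunks = []
    for j in range(0, video.shape[2], mini_batch_encoder):
        chunks.append(encode_fn(video[:, :, j:j + mini_batch_encoder]))
    return torch.cat(chunks, dim=2)

# Toy usage: pooling stands in for the VAE encoder (an assumption, not the real model).
video = torch.randn(1, 3, 16, 64, 64)
latents = encode_in_temporal_chunks(video, lambda x: F.avg_pool3d(x, (1, 8, 8)), mini_batch_encoder=4)
print(latents.shape)  # torch.Size([1, 3, 16, 8, 8])
```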
 
582
+ def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
583
+ if video.size()[2] <= mini_batch_encoder:
584
+ return video
585
+ prefix_index_before = mini_batch_encoder // 2
586
+ prefix_index_after = mini_batch_encoder - prefix_index_before
587
+ pixel_values = video[:, :, prefix_index_before:-prefix_index_after]
588
+
589
+ if self.vae.slice_compression_vae:
590
+ latents = self.vae.encode(pixel_values)[0]
591
+ latents = latents.sample()
592
+ else:
593
+ new_pixel_values = []
594
+ for i in range(0, pixel_values.shape[2], mini_batch_encoder):
595
+ with torch.no_grad():
596
+ pixel_values_bs = pixel_values[:, :, i: i + mini_batch_encoder, :, :]
597
+ pixel_values_bs = self.vae.encode(pixel_values_bs)[0]
598
+ pixel_values_bs = pixel_values_bs.sample()
599
+ new_pixel_values.append(pixel_values_bs)
600
+ latents = torch.cat(new_pixel_values, dim = 2)
601
+
602
+ if self.vae.slice_compression_vae:
603
+ middle_video = self.vae.decode(latents)[0]
604
+ else:
605
+ middle_video = []
606
  for i in range(0, latents.shape[2], mini_batch_decoder):
607
  with torch.no_grad():
608
  start_index = i
609
  end_index = i + mini_batch_decoder
610
  latents_bs = self.vae.decode(latents[:, :, start_index:end_index, :, :])[0]
611
+ middle_video.append(latents_bs)
612
+ middle_video = torch.cat(middle_video, 2)
613
+ video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
614
+ return video
615
+
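A minimal sketch of the seam-smoothing idea in `smooth_output` above: a copy of the video whose chunk boundaries are shifted by half a window is re-encoded and decoded, then averaged with the original frames so the seams of the two passes do not coincide. `roundtrip_fn` is a hypothetical stand-in for the VAE encode-plus-decode pass.

```python
import torch

def smooth_chunk_seams(video: torch.Tensor, window: int, roundtrip_fn) -> torch.Tensor:
    # Average the original frames with a re-decoded copy shifted by half a window.
    if video.size(2) <= window:
        return video
    before = window // 2
    after = window - before
    middle = roundtrip_fn(video[:, :, before:-after])   # hypothetical VAE round trip
    video = video.clone()
    video[:, :, before:-after] = (video[:, :, before:-after] + middle) / 2
    return video

# Toy usage with an identity "round trip" standing in for encode + decode.
frames = torch.randn(1, 3, 12, 8, 8)
print(smooth_chunk_seams(frames, window=4, roundtrip_fn=lambda x: x).shape)  # (1, 3, 12, 8, 8)
```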
616
+ def decode_latents(self, latents):
617
+ video_length = latents.shape[2]
618
+ latents = 1 / self.vae.config.scaling_factor * latents
619
+ if self.vae.quant_conv.weight.ndim==5:
620
+ mini_batch_encoder = self.vae.mini_batch_encoder
621
+ mini_batch_decoder = self.vae.mini_batch_decoder
622
+ if self.vae.slice_compression_vae:
623
+ video = self.vae.decode(latents)[0]
624
+ else:
625
+ video = []
626
+ for i in range(0, latents.shape[2], mini_batch_decoder):
627
+ with torch.no_grad():
628
+ start_index = i
629
+ end_index = i + mini_batch_decoder
630
+ latents_bs = self.vae.decode(latents[:, :, start_index:end_index, :, :])[0]
631
+ video.append(latents_bs)
632
+ video = torch.cat(video, 2)
633
  video = video.clamp(-1, 1)
634
+ video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
635
  else:
636
  latents = rearrange(latents, "b c f h w -> (b f) c h w")
 
637
  video = []
638
  for frame_idx in tqdm(range(latents.shape[0])):
639
  video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
 
658
 
659
  return image_latents
660
 
661
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
662
+ def get_timesteps(self, num_inference_steps, strength, device):
663
+ # get the original timestep using init_timestep
664
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
665
+
666
+ t_start = max(num_inference_steps - init_timestep, 0)
667
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
668
+
669
+ return timesteps, num_inference_steps - t_start
670
+
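The truncation in `get_timesteps` can be checked with a quick worked example, assuming a scheduler of order 1 and a plain 50-entry schedule (the real values come from `self.scheduler.timesteps`):

```python
# With strength 0.6 only the last int(50 * 0.6) = 30 scheduler steps are run.
num_inference_steps, strength = 50, 0.6
timesteps = list(range(999, -1, -20))        # stand-in for a 50-entry schedule

init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 30
t_start = max(num_inference_steps - init_timestep, 0)                           # 20
used = timesteps[t_start:]                                                       # last 30 entries
print(len(used), used[0])  # 30 599
```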
671
  def prepare_mask_latents(
672
  self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
673
  ):
 
679
  mask = mask.to(device=device, dtype=self.vae.dtype)
680
  if self.vae.quant_conv.weight.ndim==5:
681
  bs = 1
682
+ mini_batch_encoder = self.vae.mini_batch_encoder
683
  new_mask = []
684
+ if self.vae.slice_compression_vae:
685
+ for i in range(0, mask.shape[0], bs):
686
+ mask_bs = mask[i : i + bs]
 
 
687
  mask_bs = self.vae.encode(mask_bs)[0]
688
  mask_bs = mask_bs.sample()
689
+ new_mask.append(mask_bs)
690
+ else:
691
+ for i in range(0, mask.shape[0], bs):
692
+ new_mask_mini_batch = []
693
+ for j in range(0, mask.shape[2], mini_batch_encoder):
694
+ mask_bs = mask[i : i + bs, :, j: j + mini_batch_encoder, :, :]
695
+ mask_bs = self.vae.encode(mask_bs)[0]
696
+ mask_bs = mask_bs.sample()
697
+ new_mask_mini_batch.append(mask_bs)
698
+ new_mask_mini_batch = torch.cat(new_mask_mini_batch, dim = 2)
699
+ new_mask.append(new_mask_mini_batch)
700
  mask = torch.cat(new_mask, dim = 0)
701
+ mask = mask * self.vae.config.scaling_factor
702
 
703
  else:
704
  if mask.shape[1] == 4:
 
712
  masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
713
  if self.vae.quant_conv.weight.ndim==5:
714
  bs = 1
715
+ mini_batch_encoder = self.vae.mini_batch_encoder
716
  new_mask_pixel_values = []
717
+ if self.vae.slice_compression_vae:
718
+ for i in range(0, masked_image.shape[0], bs):
719
+ mask_pixel_values_bs = masked_image[i : i + bs]
 
 
720
  mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
721
  mask_pixel_values_bs = mask_pixel_values_bs.sample()
722
+ new_mask_pixel_values.append(mask_pixel_values_bs)
723
+ else:
724
+ for i in range(0, masked_image.shape[0], bs):
725
+ new_mask_pixel_values_mini_batch = []
726
+ for j in range(0, masked_image.shape[2], mini_batch_encoder):
727
+ mask_pixel_values_bs = masked_image[i : i + bs, :, j: j + mini_batch_encoder, :, :]
728
+ mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
729
+ mask_pixel_values_bs = mask_pixel_values_bs.sample()
730
+ new_mask_pixel_values_mini_batch.append(mask_pixel_values_bs)
731
+ new_mask_pixel_values_mini_batch = torch.cat(new_mask_pixel_values_mini_batch, dim = 2)
732
+ new_mask_pixel_values.append(new_mask_pixel_values_mini_batch)
733
  masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
734
+ masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
735
 
736
  else:
737
  if masked_image.shape[1] == 4:
 
776
  callback_steps: int = 1,
777
  clean_caption: bool = True,
778
  mask_feature: bool = True,
779
+ max_sequence_length: int = 120,
780
+ clip_image: Image = None,
781
+ clip_apply_ratio: float = 0.50,
782
  ) -> Union[EasyAnimatePipelineOutput, Tuple]:
783
  """
784
  Function invoked when calling the pipeline for generation.
 
852
  # 1. Check inputs. Raise error if not correct
853
  height = height or self.transformer.config.sample_size * self.vae_scale_factor
854
  width = width or self.transformer.config.sample_size * self.vae_scale_factor
855
+ height = int(height // 16 * 16)
856
+ width = int(width // 16 * 16)
857
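The two added lines above simply floor the requested resolution to a multiple of 16, presumably so the size divides cleanly through the VAE and patching stages; for example:

```python
# Requested sizes are floored to the nearest lower multiple of 16.
for requested in (384, 720, 1080):
    print(requested, "->", int(requested // 16 * 16))  # 384 -> 384, 720 -> 720, 1080 -> 1072
```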
 
858
  # 2. Default height and width to transformer
859
  if prompt is not None and isinstance(prompt, str):
 
893
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
894
  prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
895
 
896
+ # 4. set timesteps
897
  self.scheduler.set_timesteps(num_inference_steps, device=device)
898
+ timesteps, num_inference_steps = self.get_timesteps(
899
+ num_inference_steps=num_inference_steps, strength=strength, device=device
900
+ )
901
  # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
902
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
903
  # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
904
  is_strength_max = strength == 1.0
905
 
 
914
  # Prepare latent variables
915
  num_channels_latents = self.vae.config.latent_channels
916
  num_channels_transformer = self.transformer.config.in_channels
917
+ return_image_latents = True # num_channels_transformer == 4
918
 
919
  # 5. Prepare latents.
920
  latents_outputs = self.prepare_latents(
 
946
  mask_condition = mask_condition.to(dtype=torch.float32)
947
  mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
948
 
949
+ if num_channels_transformer == 12:
950
+ mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1])
951
+ if masked_video_latents is None:
952
+ masked_video = init_video * (mask_condition_tile < 0.5) + torch.ones_like(init_video) * (mask_condition_tile > 0.5) * -1
953
+ else:
954
+ masked_video = masked_video_latents
955
+
956
+ mask_latents, masked_video_latents = self.prepare_mask_latents(
957
+ mask_condition_tile,
958
+ masked_video,
959
+ batch_size,
960
+ height,
961
+ width,
962
+ prompt_embeds.dtype,
963
+ device,
964
+ generator,
965
+ do_classifier_free_guidance,
966
+ )
967
+ mask = torch.tile(mask_condition, [1, num_channels_transformer // 3, 1, 1, 1])
968
+ mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
969
+
970
+ mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
971
+ masked_video_latents_input = (
972
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
973
+ )
974
+ inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
975
  else:
976
+ mask = torch.tile(mask_condition, [1, num_channels_transformer, 1, 1, 1])
977
+ mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
978
+
979
+ inpaint_latents = None
980
+ else:
981
+ if num_channels_transformer == 12:
982
+ mask = torch.zeros_like(latents).to(latents.device, latents.dtype)
983
+ masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
984
+
985
+ mask_input = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
986
+ masked_video_latents_input = (
987
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
988
+ )
989
+ inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
990
+ else:
991
+ mask = torch.zeros_like(init_video[:, :1])
992
+ mask = torch.tile(mask, [1, num_channels_transformer, 1, 1, 1])
993
+ mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
994
+
995
+ inpaint_latents = None
996
+
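For the 12-channel transformer branch above, the channel accounting works out as 4 noise-latent + 4 mask-latent + 4 masked-video-latent channels (assuming the usual 4 VAE latent channels; the check further down in this hunk enforces exactly this sum). A shape-only sketch, treating the in-model concatenation of `inpaint_latents` onto the noise latents as an assumption:

```python
import torch

b, c_lat, f, h, w = 2, 4, 16, 48, 84                    # 4 latent channels assumed
latents = torch.randn(b, c_lat, f, h, w)                 # noise latents for the denoising loop
mask_latents = torch.randn(b, c_lat, f, h, w)            # VAE-encoded, tiled mask
masked_video_latents = torch.randn(b, c_lat, f, h, w)    # VAE-encoded masked video

inpaint_latents = torch.cat([mask_latents, masked_video_latents], dim=1)
model_input = torch.cat([latents, inpaint_latents], dim=1)  # assumed to happen inside the transformer
print(inpaint_latents.shape[1], model_input.shape[1])       # 8 12
```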
997
+ if clip_image is not None:
998
+ inputs = self.clip_image_processor(images=clip_image, return_tensors="pt")
999
+ inputs["pixel_values"] = inputs["pixel_values"].to(latents.device, dtype=latents.dtype)
1000
+ clip_encoder_hidden_states = self.clip_image_encoder(**inputs).image_embeds
1001
+ clip_encoder_hidden_states_neg = torch.zeros([batch_size, 768]).to(latents.device, dtype=latents.dtype)
1002
+
1003
+ clip_attention_mask = torch.ones([batch_size, 8]).to(latents.device, dtype=latents.dtype)
1004
+ clip_attention_mask_neg = torch.zeros([batch_size, 8]).to(latents.device, dtype=latents.dtype)
1005
+
1006
+ clip_encoder_hidden_states_input = torch.cat([clip_encoder_hidden_states_neg, clip_encoder_hidden_states]) if do_classifier_free_guidance else clip_encoder_hidden_states
1007
+ clip_attention_mask_input = torch.cat([clip_attention_mask_neg, clip_attention_mask]) if do_classifier_free_guidance else clip_attention_mask
1008
+
1009
+ elif clip_image is None and num_channels_transformer == 12:
1010
+ clip_encoder_hidden_states = torch.zeros([batch_size, 768]).to(latents.device, dtype=latents.dtype)
1011
+
1012
+ clip_attention_mask = torch.zeros([batch_size, 8])
1013
+ clip_attention_mask = clip_attention_mask.to(latents.device, dtype=latents.dtype)
1014
+
1015
+ clip_encoder_hidden_states_input = torch.cat([clip_encoder_hidden_states] * 2) if do_classifier_free_guidance else clip_encoder_hidden_states
1016
+ clip_attention_mask_input = torch.cat([clip_attention_mask] * 2) if do_classifier_free_guidance else clip_attention_mask
1017
+
1018
  else:
1019
+ clip_encoder_hidden_states_input = None
1020
+ clip_attention_mask_input = None
1021
 
1022
  # Check that sizes of mask, masked image and latents match
1023
  if num_channels_transformer == 12:
1024
  # default case for runwayml/stable-diffusion-inpainting
1025
+ num_channels_mask = mask_latents.shape[1]
1026
  num_channels_masked_image = masked_video_latents.shape[1]
1027
  if num_channels_latents + num_channels_mask + num_channels_masked_image != self.transformer.config.in_channels:
1028
  raise ValueError(
 
1032
  f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
1033
  " `pipeline.transformer` or your `mask_image` or `image` input."
1034
  )
1035
+ elif num_channels_transformer != 4:
1036
  raise ValueError(
1037
  f"The transformer {self.transformer.__class__} should have 4 or 12 input channels, not {self.transformer.config.in_channels}."
1038
  )
1039
 
1040
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1041
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1042
 
1043
  # 6.1 Prepare micro-conditions.
 
1054
 
1055
  added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
1056
 
1057
+ gc.collect()
1058
+ torch.cuda.empty_cache()
1059
+ torch.cuda.ipc_collect()
1060
 
1061
+ # 10. Denoising loop
1062
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1063
+ self._num_timesteps = len(timesteps)
1064
  with self.progress_bar(total=num_inference_steps) as progress_bar:
1065
  for i, t in enumerate(timesteps):
1066
  latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1067
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1068
 
1069
+ if i < len(timesteps) * (1 - clip_apply_ratio) and clip_encoder_hidden_states_input is not None:
1070
+ clip_encoder_hidden_states_actual_input = torch.zeros_like(clip_encoder_hidden_states_input)
1071
+ clip_attention_mask_actual_input = torch.zeros_like(clip_attention_mask_input)
1072
+ else:
1073
+ clip_encoder_hidden_states_actual_input = clip_encoder_hidden_states_input
1074
+ clip_attention_mask_actual_input = clip_attention_mask_input
1075
+
1076
  current_timestep = t
1077
  if not torch.is_tensor(current_timestep):
1078
  # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
 
1095
  encoder_attention_mask=prompt_attention_mask,
1096
  timestep=current_timestep,
1097
  added_cond_kwargs=added_cond_kwargs,
1098
+ inpaint_latents=inpaint_latents,
1099
+ clip_encoder_hidden_states=clip_encoder_hidden_states_actual_input,
1100
+ clip_attention_mask=clip_attention_mask_actual_input,
1101
  return_dict=False,
1102
  )[0]
1103
 
 
1112
  # compute previous image: x_t -> x_t-1
1113
  latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1114
 
1115
+ if num_channels_transformer == 4:
1116
+ init_latents_proper = image_latents
1117
+ init_mask = mask
1118
+ if i < len(timesteps) - 1:
1119
+ noise_timestep = timesteps[i + 1]
1120
+ init_latents_proper = self.scheduler.add_noise(
1121
+ init_latents_proper, noise, torch.tensor([noise_timestep])
1122
+ )
1123
+
1124
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
1125
+
1126
  # call the callback, if provided
1127
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1128
  progress_bar.update()
 
1130
  step_idx = i // getattr(self.scheduler, "order", 1)
1131
  callback(step_idx, t, latents)
1132
 
1133
+ gc.collect()
1134
+ torch.cuda.empty_cache()
1135
+ torch.cuda.ipc_collect()
1136
+
1137
  # Post-processing
1138
  video = self.decode_latents(latents)
1139
+
1140
+ gc.collect()
1141
+ torch.cuda.empty_cache()
1142
+ torch.cuda.ipc_collect()
1143
  # Convert to tensor
1144
  if output_type == "latent":
1145
  video = torch.from_numpy(video)
easyanimate/ui/ui.py CHANGED
@@ -1,35 +1,40 @@
1
  """Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py
2
  """
 
3
  import gc
4
  import json
5
  import os
6
  import random
7
- import base64
8
- import requests
9
- import pkg_resources
10
  from datetime import datetime
11
  from glob import glob
12
 
13
  import gradio as gr
14
- import torch
15
  import numpy as np
 
 
 
16
  from diffusers import (AutoencoderKL, DDIMScheduler,
17
  DPMSolverMultistepScheduler,
18
  EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
19
  PNDMScheduler)
20
- from easyanimate.models.autoencoder_magvit import AutoencoderKLMagvit
21
  from diffusers.utils.import_utils import is_xformers_available
22
  from omegaconf import OmegaConf
 
23
  from safetensors import safe_open
24
- from transformers import T5EncoderModel, T5Tokenizer
 
25
 
 
 
26
  from easyanimate.models.transformer3d import Transformer3DModel
27
  from easyanimate.pipeline.pipeline_easyanimate import EasyAnimatePipeline
 
 
28
  from easyanimate.utils.lora_utils import merge_lora, unmerge_lora
29
- from easyanimate.utils.utils import save_videos_grid
30
- from PIL import Image
 
31
 
32
- sample_idx = 0
33
  scheduler_dict = {
34
  "Euler": EulerDiscreteScheduler,
35
  "Euler A": EulerAncestralDiscreteScheduler,
@@ -60,8 +65,8 @@ class EasyAnimateController:
60
  self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
61
  self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
62
  self.savedir_sample = os.path.join(self.savedir, "sample")
63
- self.edition = "v2"
64
- self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_magvit_motion_module_v2.yaml"))
65
  os.makedirs(self.savedir, exist_ok=True)
66
 
67
  self.diffusion_transformer_list = []
@@ -85,14 +90,14 @@ class EasyAnimateController:
85
  self.weight_dtype = torch.bfloat16
86
 
87
  def refresh_diffusion_transformer(self):
88
- self.diffusion_transformer_list = glob(os.path.join(self.diffusion_transformer_dir, "*/"))
89
 
90
  def refresh_motion_module(self):
91
- motion_module_list = glob(os.path.join(self.motion_module_dir, "*.safetensors"))
92
  self.motion_module_list = [os.path.basename(p) for p in motion_module_list]
93
 
94
  def refresh_personalized_model(self):
95
- personalized_model_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors"))
96
  self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
97
 
98
  def update_edition(self, edition):
@@ -100,19 +105,24 @@ class EasyAnimateController:
100
  self.edition = edition
101
  if edition == "v1":
102
  self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_motion_module_v1.yaml"))
103
- return gr.Dropdown.update(), gr.update(value="none"), gr.update(visible=True), gr.update(visible=True), \
104
- gr.update(visible=False), gr.update(value=512, minimum=384, maximum=704, step=32), \
105
  gr.update(value=512, minimum=384, maximum=704, step=32), gr.update(value=80, minimum=40, maximum=80, step=1)
106
- else:
107
  self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_magvit_motion_module_v2.yaml"))
108
- return gr.Dropdown.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
109
- gr.update(visible=True), gr.update(value=672, minimum=128, maximum=1280, step=16), \
110
  gr.update(value=384, minimum=128, maximum=1280, step=16), gr.update(value=144, minimum=9, maximum=144, step=9)
 
 
 
 
 
111
 
112
  def update_diffusion_transformer(self, diffusion_transformer_dropdown):
113
  print("Update diffusion transformer")
114
  if diffusion_transformer_dropdown == "none":
115
- return gr.Dropdown.update()
116
  if OmegaConf.to_container(self.inference_config['vae_kwargs'])['enable_magvit']:
117
  Choosen_AutoencoderKL = AutoencoderKLMagvit
118
  else:
@@ -130,25 +140,42 @@ class EasyAnimateController:
130
  self.text_encoder = T5EncoderModel.from_pretrained(diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype=self.weight_dtype)
131
 
132
  # Get pipeline
133
- self.pipeline = EasyAnimatePipeline(
134
- vae=self.vae,
135
- text_encoder=self.text_encoder,
136
- tokenizer=self.tokenizer,
137
- transformer=self.transformer,
138
- scheduler=scheduler_dict["Euler"](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
139
- )
140
- self.pipeline.enable_model_cpu_offload()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  print("Update diffusion transformer done")
142
- return gr.Dropdown.update()
143
 
144
  def update_motion_module(self, motion_module_dropdown):
145
  self.motion_module_path = motion_module_dropdown
146
  print("Update motion module")
147
  if motion_module_dropdown == "none":
148
- return gr.Dropdown.update()
149
  if self.transformer is None:
150
  gr.Info(f"Please select a pretrained model path.")
151
- return gr.Dropdown.update(value=None)
152
  else:
153
  motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown)
154
  if motion_module_dropdown.endswith(".safetensors"):
@@ -160,16 +187,16 @@ class EasyAnimateController:
160
  motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu")
161
  missing, unexpected = self.transformer.load_state_dict(motion_module_state_dict, strict=False)
162
  print("Update motion module done.")
163
- return gr.Dropdown.update()
164
 
165
  def update_base_model(self, base_model_dropdown):
166
  self.base_model_path = base_model_dropdown
167
  print("Update base model")
168
  if base_model_dropdown == "none":
169
- return gr.Dropdown.update()
170
  if self.transformer is None:
171
  gr.Info(f"Please select a pretrained model path.")
172
- return gr.Dropdown.update(value=None)
173
  else:
174
  base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown)
175
  base_model_state_dict = {}
@@ -178,16 +205,16 @@ class EasyAnimateController:
178
  base_model_state_dict[key] = f.get_tensor(key)
179
  self.transformer.load_state_dict(base_model_state_dict, strict=False)
180
  print("Update base done")
181
- return gr.Dropdown.update()
182
 
183
  def update_lora_model(self, lora_model_dropdown):
184
  print("Update lora model")
185
  if lora_model_dropdown == "none":
186
  self.lora_model_path = "none"
187
- return gr.Dropdown.update()
188
  lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown)
189
  self.lora_model_path = lora_model_dropdown
190
- return gr.Dropdown.update()
191
 
192
  def generate(
193
  self,
@@ -200,15 +227,24 @@ class EasyAnimateController:
200
  negative_prompt_textbox,
201
  sampler_dropdown,
202
  sample_step_slider,
 
203
  width_slider,
204
  height_slider,
205
- is_image,
 
206
  length_slider,
 
 
207
  cfg_scale_slider,
 
 
208
  seed_textbox,
209
  is_api = False,
210
  ):
211
- global sample_idx
 
 
 
212
  if self.transformer is None:
213
  raise gr.Error(f"Please select a pretrained model path.")
214
 
@@ -221,6 +257,39 @@ class EasyAnimateController:
221
  if self.lora_model_path != lora_model_dropdown:
222
  print("Update lora model")
223
  self.update_lora_model(lora_model_dropdown)
 
 
 
224
 
225
  if is_xformers_available(): self.transformer.enable_xformers_memory_efficient_attention()
226
 
@@ -235,16 +304,98 @@ class EasyAnimateController:
235
  generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
236
 
237
  try:
238
- sample = self.pipeline(
239
- prompt_textbox,
240
- negative_prompt = negative_prompt_textbox,
241
- num_inference_steps = sample_step_slider,
242
- guidance_scale = cfg_scale_slider,
243
- width = width_slider,
244
- height = height_slider,
245
- video_length = length_slider if not is_image else 1,
246
- generator = generator
247
- ).videos
 
 
248
  except Exception as e:
249
  gc.collect()
250
  torch.cuda.empty_cache()
@@ -254,7 +405,11 @@ class EasyAnimateController:
254
  if is_api:
255
  return "", f"Error. error information is {str(e)}"
256
  else:
257
- return gr.Image.update(), gr.Video.update(), f"Error. error information is {str(e)}"
 
 
 
 
258
 
259
  # lora part
260
  if self.lora_model_path != "none":
@@ -296,7 +451,10 @@ class EasyAnimateController:
296
  if is_api:
297
  return save_sample_path, "Success"
298
  else:
299
- return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
 
 
 
300
  else:
301
  save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4")
302
  save_videos_grid(sample, save_sample_path, fps=12 if self.edition == "v1" else 24)
@@ -304,7 +462,10 @@ class EasyAnimateController:
304
  if is_api:
305
  return save_sample_path, "Success"
306
  else:
307
- return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
 
 
 
308
 
309
 
310
  def ui():
@@ -325,24 +486,24 @@ def ui():
325
  with gr.Column(variant="panel"):
326
  gr.Markdown(
327
  """
328
- ### 1. EasyAnimate Edition (select easyanimate edition first).
329
  """
330
  )
331
  with gr.Row():
332
  easyanimate_edition_dropdown = gr.Dropdown(
333
- label="The config of EasyAnimate Edition",
334
- choices=["v1", "v2"],
335
- value="v2",
336
  interactive=True,
337
  )
338
  gr.Markdown(
339
  """
340
- ### 2. Model checkpoints (select pretrained model path).
341
  """
342
  )
343
  with gr.Row():
344
  diffusion_transformer_dropdown = gr.Dropdown(
345
- label="Pretrained Model Path",
346
  choices=controller.diffusion_transformer_list,
347
  value="none",
348
  interactive=True,
@@ -356,12 +517,12 @@ def ui():
356
  diffusion_transformer_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
357
  def refresh_diffusion_transformer():
358
  controller.refresh_diffusion_transformer()
359
- return gr.Dropdown.update(choices=controller.diffusion_transformer_list)
360
  diffusion_transformer_refresh_button.click(fn=refresh_diffusion_transformer, inputs=[], outputs=[diffusion_transformer_dropdown])
361
 
362
  with gr.Row():
363
  motion_module_dropdown = gr.Dropdown(
364
- label="Select motion module",
365
  choices=controller.motion_module_list,
366
  value="none",
367
  interactive=True,
@@ -371,78 +532,139 @@ def ui():
371
  motion_module_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton", visible=False)
372
  def update_motion_module():
373
  controller.refresh_motion_module()
374
- return gr.Dropdown.update(choices=controller.motion_module_list)
375
  motion_module_refresh_button.click(fn=update_motion_module, inputs=[], outputs=[motion_module_dropdown])
376
 
377
  base_model_dropdown = gr.Dropdown(
378
- label="Select base Dreambooth model (optional)",
379
  choices=controller.personalized_model_list,
380
  value="none",
381
  interactive=True,
382
  )
383
 
384
  lora_model_dropdown = gr.Dropdown(
385
- label="Select LoRA model (optional)",
386
  choices=["none"] + controller.personalized_model_list,
387
  value="none",
388
  interactive=True,
389
  )
390
 
391
- lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.55, minimum=0, maximum=2, interactive=True)
392
 
393
  personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
394
  def update_personalized_model():
395
  controller.refresh_personalized_model()
396
  return [
397
- gr.Dropdown.update(choices=controller.personalized_model_list),
398
- gr.Dropdown.update(choices=["none"] + controller.personalized_model_list)
399
  ]
400
  personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])
401
 
402
  with gr.Column(variant="panel"):
403
  gr.Markdown(
404
  """
405
- ### 3. Configs for Generation.
406
  """
407
  )
408
 
409
- prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="This video shows the majestic beauty of a waterfall cascading down a cliff into a serene lake. The waterfall, with its powerful flow, is the central focus of the video. The surrounding landscape is lush and green, with trees and foliage adding to the natural beauty of the scene")
410
- negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion. " )
411
 
412
  with gr.Row():
413
  with gr.Column():
414
  with gr.Row():
415
- sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
416
- sample_step_slider = gr.Slider(label="Sampling steps", value=50, minimum=10, maximum=100, step=1)
417
 
418
- width_slider = gr.Slider(label="Width", value=672, minimum=128, maximum=1280, step=16)
419
- height_slider = gr.Slider(label="Height", value=384, minimum=128, maximum=1280, step=16)
420
- with gr.Row():
421
- is_image = gr.Checkbox(False, label="Generate Image")
422
- length_slider = gr.Slider(label="Animation length", value=144, minimum=9, maximum=144, step=9)
423
- cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.0, minimum=0, maximum=20)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
424
 
425
  with gr.Row():
426
- seed_textbox = gr.Textbox(label="Seed", value=43)
427
  seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
428
- seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, 1e8)), inputs=[], outputs=[seed_textbox])
 
 
 
 
429
 
430
- generate_button = gr.Button(value="Generate", variant='primary')
431
 
432
  with gr.Column():
433
- result_image = gr.Image(label="Generated Image", interactive=False, visible=False)
434
- result_video = gr.Video(label="Generated Animation", interactive=False)
435
  infer_progress = gr.Textbox(
436
- label="Generation Info",
437
  value="No task currently",
438
  interactive=False
439
  )
440
 
441
- is_image.change(
442
- lambda x: gr.update(visible=not x),
443
- inputs=[is_image],
444
- outputs=[length_slider],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445
  )
 
446
  easyanimate_edition_dropdown.change(
447
  fn=controller.update_edition,
448
  inputs=[easyanimate_edition_dropdown],
@@ -451,7 +673,6 @@ def ui():
451
  diffusion_transformer_dropdown,
452
  motion_module_dropdown,
453
  motion_module_refresh_button,
454
- is_image,
455
  width_slider,
456
  height_slider,
457
  length_slider,
@@ -469,11 +690,17 @@ def ui():
469
  negative_prompt_textbox,
470
  sampler_dropdown,
471
  sample_step_slider,
 
472
  width_slider,
473
  height_slider,
474
- is_image,
 
475
  length_slider,
 
 
476
  cfg_scale_slider,
 
 
477
  seed_textbox,
478
  ],
479
  outputs=[result_image, result_video, infer_progress]
@@ -483,11 +710,18 @@ def ui():
483
 
484
  class EasyAnimateController_Modelscope:
485
  def __init__(self, edition, config_path, model_name, savedir_sample):
486
- # Config and model path
487
- weight_dtype = torch.bfloat16
488
- self.savedir_sample = savedir_sample
 
 
 
 
 
 
489
  os.makedirs(self.savedir_sample, exist_ok=True)
490
 
 
491
  self.edition = edition
492
  self.inference_config = OmegaConf.load(config_path)
493
  # Get Transformer
@@ -513,32 +747,107 @@ class EasyAnimateController_Modelscope:
513
  subfolder="text_encoder",
514
  torch_dtype=weight_dtype
515
  )
516
- self.pipeline = EasyAnimatePipeline(
517
- vae=self.vae,
518
- text_encoder=self.text_encoder,
519
- tokenizer=self.tokenizer,
520
- transformer=self.transformer,
521
- scheduler=scheduler_dict["Euler"](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
522
- )
523
- self.pipeline.enable_model_cpu_offload()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
524
  print("Update diffusion transformer done")
525
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526
  def generate(
527
  self,
 
 
 
 
 
528
  prompt_textbox,
529
  negative_prompt_textbox,
530
  sampler_dropdown,
531
  sample_step_slider,
 
532
  width_slider,
533
  height_slider,
534
- is_image,
 
535
  length_slider,
536
  cfg_scale_slider,
537
- seed_textbox
 
 
 
538
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
539
  if is_xformers_available(): self.transformer.enable_xformers_memory_efficient_attention()
540
 
541
  self.pipeline.scheduler = scheduler_dict[sampler_dropdown](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
 
 
 
542
  self.pipeline.to("cuda")
543
 
544
  if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
@@ -546,21 +855,52 @@ class EasyAnimateController_Modelscope:
546
  generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
547
 
548
  try:
549
- sample = self.pipeline(
550
- prompt_textbox,
551
- negative_prompt = negative_prompt_textbox,
552
- num_inference_steps = sample_step_slider,
553
- guidance_scale = cfg_scale_slider,
554
- width = width_slider,
555
- height = height_slider,
556
- video_length = length_slider if not is_image else 1,
557
- generator = generator
558
- ).videos
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
559
  except Exception as e:
560
  gc.collect()
561
  torch.cuda.empty_cache()
562
  torch.cuda.ipc_collect()
563
- return gr.Image.update(), gr.Video.update(), f"Error. error information is {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
564
 
565
  if not os.path.exists(self.savedir_sample):
566
  os.makedirs(self.savedir_sample, exist_ok=True)
@@ -578,11 +918,23 @@ class EasyAnimateController_Modelscope:
578
  image = (image * 255).numpy().astype(np.uint8)
579
  image = Image.fromarray(image)
580
  image.save(save_sample_path)
581
- return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
 
 
 
 
 
 
582
  else:
583
  save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4")
584
  save_videos_grid(sample, save_sample_path, fps=12 if self.edition == "v1" else 24)
585
- return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
 
 
 
 
 
 
586
 
587
 
588
  def ui_modelscope(edition, config_path, model_name, savedir_sample):
@@ -601,71 +953,197 @@ def ui_modelscope(edition, config_path, model_name, savedir_sample):
601
  """
602
  )
603
  with gr.Column(variant="panel"):
604
- prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="This video shows the majestic beauty of a waterfall cascading down a cliff into a serene lake. The waterfall, with its powerful flow, is the central focus of the video. The surrounding landscape is lush and green, with trees and foliage adding to the natural beauty of the scene")
605
- negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion. " )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
606
 
607
  with gr.Row():
608
  with gr.Column():
609
  with gr.Row():
610
- sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
611
- sample_step_slider = gr.Slider(label="Sampling steps", value=30, minimum=10, maximum=100, step=1)
612
 
613
  if edition == "v1":
614
- width_slider = gr.Slider(label="Width", value=512, minimum=384, maximum=704, step=32)
615
- height_slider = gr.Slider(label="Height", value=512, minimum=384, maximum=704, step=32)
616
- with gr.Row():
617
- is_image = gr.Checkbox(False, label="Generate Image", visible=False)
618
- length_slider = gr.Slider(label="Animation length", value=80, minimum=40, maximum=96, step=1)
619
- cfg_scale_slider = gr.Slider(label="CFG Scale", value=6.0, minimum=0, maximum=20)
 
 
 
 
 
 
620
  else:
621
- width_slider = gr.Slider(label="Width", value=672, minimum=256, maximum=704, step=16)
622
- height_slider = gr.Slider(label="Height", value=384, minimum=256, maximum=704, step=16)
 
 
 
623
  with gr.Column():
624
  gr.Markdown(
625
  """
626
- To ensure the efficiency of the trial, we will limit the frame rate to no more than 81.
627
- If you want to experience longer video generation, you can go to our [Github](https://github.com/aigc-apps/EasyAnimate/).
 
 
 
628
  """
629
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
630
  with gr.Row():
631
- is_image = gr.Checkbox(False, label="Generate Image")
632
- length_slider = gr.Slider(label="Animation length", value=72, minimum=9, maximum=81, step=9)
633
- cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.0, minimum=0, maximum=20)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
634
 
635
  with gr.Row():
636
- seed_textbox = gr.Textbox(label="Seed", value=43)
637
  seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
638
- seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, 1e8)), inputs=[], outputs=[seed_textbox])
 
 
 
 
639
 
640
- generate_button = gr.Button(value="Generate", variant='primary')
641
 
642
  with gr.Column():
643
- result_image = gr.Image(label="Generated Image", interactive=False, visible=False)
644
- result_video = gr.Video(label="Generated Animation", interactive=False)
645
  infer_progress = gr.Textbox(
646
- label="Generation Info",
647
  value="No task currently",
648
  interactive=False
649
  )
650
 
651
- is_image.change(
652
- lambda x: gr.update(visible=not x),
653
- inputs=[is_image],
654
- outputs=[length_slider],
 
 
 
 
 
 
 
 
 
 
 
 
655
  )
656
 
657
  generate_button.click(
658
  fn=controller.generate,
659
  inputs=[
 
 
 
 
 
660
  prompt_textbox,
661
  negative_prompt_textbox,
662
  sampler_dropdown,
663
  sample_step_slider,
 
664
  width_slider,
665
  height_slider,
666
- is_image,
 
667
  length_slider,
668
  cfg_scale_slider,
 
 
669
  seed_textbox,
670
  ],
671
  outputs=[result_image, result_video, infer_progress]
@@ -674,31 +1152,51 @@ def ui_modelscope(edition, config_path, model_name, savedir_sample):
674
 
675
 
676
  def post_eas(
 
 
677
  prompt_textbox, negative_prompt_textbox,
678
- sampler_dropdown, sample_step_slider, width_slider, height_slider,
679
- is_image, length_slider, cfg_scale_slider, seed_textbox,
 
680
  ):
 
 
 
 
 
 
 
 
 
 
 
 
681
  datas = {
682
- "base_model_path": "none",
683
- "motion_module_path": "none",
684
- "lora_model_path": "none",
685
- "lora_alpha_slider": 0.55,
686
  "prompt_textbox": prompt_textbox,
687
  "negative_prompt_textbox": negative_prompt_textbox,
688
  "sampler_dropdown": sampler_dropdown,
689
  "sample_step_slider": sample_step_slider,
 
690
  "width_slider": width_slider,
691
  "height_slider": height_slider,
692
- "is_image": is_image,
 
693
  "length_slider": length_slider,
694
  "cfg_scale_slider": cfg_scale_slider,
 
 
695
  "seed_textbox": seed_textbox,
696
  }
697
- # Token可以在公网地址调用信息中获取,详情请参见通用公网调用部分。
698
  session = requests.session()
699
  session.headers.update({"Authorization": os.environ.get("EAS_TOKEN")})
700
 
701
- response = session.post(url=f'{os.environ.get("EAS_URL")}/easyanimate/infer_forward', json=datas)
 
702
  outputs = response.json()
703
  return outputs
704
 
@@ -710,23 +1208,42 @@ class EasyAnimateController_EAS:
710
 
711
  def generate(
712
  self,
 
 
 
 
 
713
  prompt_textbox,
714
  negative_prompt_textbox,
715
  sampler_dropdown,
716
  sample_step_slider,
 
717
  width_slider,
718
  height_slider,
719
- is_image,
 
720
  length_slider,
721
  cfg_scale_slider,
 
 
722
  seed_textbox
723
  ):
 
 
724
  outputs = post_eas(
 
 
725
  prompt_textbox, negative_prompt_textbox,
726
- sampler_dropdown, sample_step_slider, width_slider, height_slider,
727
- is_image, length_slider, cfg_scale_slider, seed_textbox
 
 
728
  )
729
- base64_encoding = outputs["base64_encoding"]
 
 
 
 
730
  decoded_data = base64.b64decode(base64_encoding)
731
 
732
  if not os.path.exists(self.savedir_sample):
@@ -768,35 +1285,134 @@ def ui_eas(edition, config_path, model_name, savedir_sample):
768
  """
769
  )
770
  with gr.Column(variant="panel"):
771
- prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="This video shows the majestic beauty of a waterfall cascading down a cliff into a serene lake. The waterfall, with its powerful flow, is the central focus of the video. The surrounding landscape is lush and green, with trees and foliage adding to the natural beauty of the scene")
 
 
772
  negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion. " )
773
 
774
  with gr.Row():
775
  with gr.Column():
776
  with gr.Row():
777
  sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
778
- sample_step_slider = gr.Slider(label="Sampling steps", value=30, minimum=10, maximum=100, step=1)
779
 
780
  if edition == "v1":
781
  width_slider = gr.Slider(label="Width", value=512, minimum=384, maximum=704, step=32)
782
  height_slider = gr.Slider(label="Height", value=512, minimum=384, maximum=704, step=32)
783
- with gr.Row():
784
- is_image = gr.Checkbox(False, label="Generate Image", visible=False)
785
- length_slider = gr.Slider(label="Animation length", value=80, minimum=40, maximum=96, step=1)
 
 
 
 
 
 
786
  cfg_scale_slider = gr.Slider(label="CFG Scale", value=6.0, minimum=0, maximum=20)
787
  else:
788
- width_slider = gr.Slider(label="Width", value=672, minimum=256, maximum=704, step=16)
789
- height_slider = gr.Slider(label="Height", value=384, minimum=256, maximum=704, step=16)
 
 
 
790
  with gr.Column():
791
  gr.Markdown(
792
  """
793
- To ensure the efficiency of the trial, we will limit the frame rate to no more than 81.
794
- If you want to experience longer video generation, you can go to our [Github](https://github.com/aigc-apps/EasyAnimate/).
 
 
 
795
  """
796
  )
797
- with gr.Row():
798
- is_image = gr.Checkbox(False, label="Generate Image")
799
- length_slider = gr.Slider(label="Animation length", value=72, minimum=9, maximum=81, step=9)
 
800
  cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.0, minimum=0, maximum=20)
801
 
802
  with gr.Row():
@@ -819,24 +1435,45 @@ def ui_eas(edition, config_path, model_name, savedir_sample):
819
  interactive=False
820
  )
821
 
822
- is_image.change(
823
- lambda x: gr.update(visible=not x),
824
- inputs=[is_image],
825
- outputs=[length_slider],
 
 
 
 
 
 
 
 
 
 
 
 
826
  )
827
 
828
  generate_button.click(
829
  fn=controller.generate,
830
  inputs=[
 
 
 
 
 
831
  prompt_textbox,
832
  negative_prompt_textbox,
833
  sampler_dropdown,
834
  sample_step_slider,
 
835
  width_slider,
836
  height_slider,
837
- is_image,
 
838
  length_slider,
839
  cfg_scale_slider,
 
 
840
  seed_textbox,
841
  ],
842
  outputs=[result_image, result_video, infer_progress]
 
1
  """Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py
2
  """
3
+ import base64
4
  import gc
5
  import json
6
  import os
7
  import random
 
 
 
8
  from datetime import datetime
9
  from glob import glob
10
 
11
  import gradio as gr
 
12
  import numpy as np
13
+ import pkg_resources
14
+ import requests
15
+ import torch
16
  from diffusers import (AutoencoderKL, DDIMScheduler,
17
  DPMSolverMultistepScheduler,
18
  EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
19
  PNDMScheduler)
 
20
  from diffusers.utils.import_utils import is_xformers_available
21
  from omegaconf import OmegaConf
22
+ from PIL import Image
23
  from safetensors import safe_open
24
+ from transformers import (CLIPImageProcessor, CLIPVisionModelWithProjection,
25
+ T5EncoderModel, T5Tokenizer)
26
 
27
+ from easyanimate.data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
28
+ from easyanimate.models.autoencoder_magvit import AutoencoderKLMagvit
29
  from easyanimate.models.transformer3d import Transformer3DModel
30
  from easyanimate.pipeline.pipeline_easyanimate import EasyAnimatePipeline
31
+ from easyanimate.pipeline.pipeline_easyanimate_inpaint import \
32
+ EasyAnimateInpaintPipeline
33
  from easyanimate.utils.lora_utils import merge_lora, unmerge_lora
34
+ from easyanimate.utils.utils import (
35
+ get_image_to_video_latent,
36
+ get_width_and_height_from_image_and_base_resolution, save_videos_grid)
37
 
 
38
  scheduler_dict = {
39
  "Euler": EulerDiscreteScheduler,
40
  "Euler A": EulerAncestralDiscreteScheduler,
 
65
  self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
66
  self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
67
  self.savedir_sample = os.path.join(self.savedir, "sample")
68
+ self.edition = "v3"
69
+ self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_slicevae_motion_module_v3.yaml"))
70
  os.makedirs(self.savedir, exist_ok=True)
71
 
72
  self.diffusion_transformer_list = []
 
90
  self.weight_dtype = torch.bfloat16
91
 
92
  def refresh_diffusion_transformer(self):
93
+ self.diffusion_transformer_list = sorted(glob(os.path.join(self.diffusion_transformer_dir, "*/")))
94
 
95
  def refresh_motion_module(self):
96
+ motion_module_list = sorted(glob(os.path.join(self.motion_module_dir, "*.safetensors")))
97
  self.motion_module_list = [os.path.basename(p) for p in motion_module_list]
98
 
99
  def refresh_personalized_model(self):
100
+ personalized_model_list = sorted(glob(os.path.join(self.personalized_model_dir, "*.safetensors")))
101
  self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
102
 
103
  def update_edition(self, edition):
 
105
  self.edition = edition
106
  if edition == "v1":
107
  self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_motion_module_v1.yaml"))
108
+ return gr.update(), gr.update(value="none"), gr.update(visible=True), gr.update(visible=True), \
109
+ gr.update(value=512, minimum=384, maximum=704, step=32), \
110
  gr.update(value=512, minimum=384, maximum=704, step=32), gr.update(value=80, minimum=40, maximum=80, step=1)
111
+ elif edition == "v2":
112
  self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_magvit_motion_module_v2.yaml"))
113
+ return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
114
+ gr.update(value=672, minimum=128, maximum=1280, step=16), \
115
  gr.update(value=384, minimum=128, maximum=1280, step=16), gr.update(value=144, minimum=9, maximum=144, step=9)
116
+ else:
117
+ self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_slicevae_motion_module_v3.yaml"))
118
+ return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
119
+ gr.update(value=672, minimum=128, maximum=1280, step=16), \
120
+ gr.update(value=384, minimum=128, maximum=1280, step=16), gr.update(value=144, minimum=8, maximum=144, step=8)
121
 
122
  def update_diffusion_transformer(self, diffusion_transformer_dropdown):
123
  print("Update diffusion transformer")
124
  if diffusion_transformer_dropdown == "none":
125
+ return gr.update()
126
  if OmegaConf.to_container(self.inference_config['vae_kwargs'])['enable_magvit']:
127
  Choosen_AutoencoderKL = AutoencoderKLMagvit
128
  else:
 
140
  self.text_encoder = T5EncoderModel.from_pretrained(diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype=self.weight_dtype)
141
 
142
  # Get pipeline
143
+ if self.transformer.config.in_channels != 12:
144
+ self.pipeline = EasyAnimatePipeline(
145
+ vae=self.vae,
146
+ text_encoder=self.text_encoder,
147
+ tokenizer=self.tokenizer,
148
+ transformer=self.transformer,
149
+ scheduler=scheduler_dict["Euler"](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
150
+ )
151
+ else:
152
+ clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(
153
+ diffusion_transformer_dropdown, subfolder="image_encoder"
154
+ ).to("cuda", self.weight_dtype)
155
+ clip_image_processor = CLIPImageProcessor.from_pretrained(
156
+ diffusion_transformer_dropdown, subfolder="image_encoder"
157
+ )
158
+ self.pipeline = EasyAnimateInpaintPipeline(
159
+ vae=self.vae,
160
+ text_encoder=self.text_encoder,
161
+ tokenizer=self.tokenizer,
162
+ transformer=self.transformer,
163
+ scheduler=scheduler_dict["Euler"](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs)),
164
+ clip_image_encoder=clip_image_encoder,
165
+ clip_image_processor=clip_image_processor,
166
+ )
167
+
168
  print("Update diffusion transformer done")
169
+ return gr.update()
170
 
171
  def update_motion_module(self, motion_module_dropdown):
172
  self.motion_module_path = motion_module_dropdown
173
  print("Update motion module")
174
  if motion_module_dropdown == "none":
175
+ return gr.update()
176
  if self.transformer is None:
177
  gr.Info(f"Please select a pretrained model path.")
178
+ return gr.update(value=None)
179
  else:
180
  motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown)
181
  if motion_module_dropdown.endswith(".safetensors"):
 
187
  motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu")
188
  missing, unexpected = self.transformer.load_state_dict(motion_module_state_dict, strict=False)
189
  print("Update motion module done.")
190
+ return gr.update()
191
 
192
  def update_base_model(self, base_model_dropdown):
193
  self.base_model_path = base_model_dropdown
194
  print("Update base model")
195
  if base_model_dropdown == "none":
196
+ return gr.update()
197
  if self.transformer is None:
198
  gr.Info(f"Please select a pretrained model path.")
199
+ return gr.update(value=None)
200
  else:
201
  base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown)
202
  base_model_state_dict = {}
 
205
  base_model_state_dict[key] = f.get_tensor(key)
206
  self.transformer.load_state_dict(base_model_state_dict, strict=False)
207
  print("Update base done")
208
+ return gr.update()
209
 
210
  def update_lora_model(self, lora_model_dropdown):
211
  print("Update lora model")
212
  if lora_model_dropdown == "none":
213
  self.lora_model_path = "none"
214
+ return gr.update()
215
  lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown)
216
  self.lora_model_path = lora_model_dropdown
217
+ return gr.update()
218
 
219
  def generate(
220
  self,
 
227
  negative_prompt_textbox,
228
  sampler_dropdown,
229
  sample_step_slider,
230
+ resize_method,
231
  width_slider,
232
  height_slider,
233
+ base_resolution,
234
+ generation_method,
235
  length_slider,
236
+ overlap_video_length,
237
+ partial_video_length,
238
  cfg_scale_slider,
239
+ start_image,
240
+ end_image,
241
  seed_textbox,
242
  is_api = False,
243
  ):
244
+ gc.collect()
245
+ torch.cuda.empty_cache()
246
+ torch.cuda.ipc_collect()
247
+
248
  if self.transformer is None:
249
  raise gr.Error(f"Please select a pretrained model path.")
250
 
 
257
  if self.lora_model_path != lora_model_dropdown:
258
  print("Update lora model")
259
  self.update_lora_model(lora_model_dropdown)
260
+
261
+ if resize_method == "Resize to the Start Image":
262
+ if start_image is None:
263
+ if is_api:
264
+ return "", f"Please upload an image when using \"Resize to the Start Image\"."
265
+ else:
266
+ raise gr.Error(f"Please upload an image when using \"Resize to the Start Image\".")
267
+
268
+ aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
269
+
270
+ original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size
271
+ closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
272
+ height_slider, width_slider = [int(x / 16) * 16 for x in closest_size]
273
+
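The resize branch above picks the aspect-ratio bucket closest to the start image, rescales it from the 512 base to `base_resolution`, and snaps the result to multiples of 16. A self-contained sketch; the three-entry bucket table and the ratio-matching helper are simplified assumptions in place of `ASPECT_RATIO_512` and `get_closest_ratio`:

```python
# Buckets map "height/width ratio" -> [height, width] at the 512 base (assumed subset).
aspect_ratio_512 = {"0.5": [352, 704], "1.0": [512, 512], "2.0": [704, 352]}
base_resolution = 768

buckets = {k: [x / 512 * base_resolution for x in v] for k, v in aspect_ratio_512.items()}

def closest_bucket(height, width):
    ratio = height / width
    key = min(buckets, key=lambda k: abs(float(k) - ratio))
    return buckets[key]

original_width, original_height = 1280, 720               # e.g. a 16:9 start image
height, width = [int(x / 16) * 16 for x in closest_bucket(original_height, original_width)]
print(height, width)  # 528 1056
```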
274
+ if self.transformer.config.in_channels != 12 and start_image is not None:
275
+ if is_api:
276
+ return "", f"Please select an image-to-video pretrained model when using image-to-video generation."
277
+ else:
278
+ raise gr.Error(f"Please select an image-to-video pretrained model when using image-to-video generation.")
279
+
280
+ if self.transformer.config.in_channels != 12 and generation_method == "Long Video Generation":
281
+ if is_api:
282
+ return "", f"Please select an image-to-video pretrained model when using long video generation."
283
+ else:
284
+ raise gr.Error(f"Please select an image-to-video pretrained model when using long video generation.")
285
+
286
+ if start_image is None and end_image is not None:
287
+ if is_api:
288
+ return "", f"If you specify an ending image for the video, please also specify a starting image."
289
+ else:
290
+ raise gr.Error(f"If you specify an ending image for the video, please also specify a starting image.")
291
+
292
+ is_image = True if generation_method == "Image Generation" else False
293
 
294
  if is_xformers_available(): self.transformer.enable_xformers_memory_efficient_attention()
295
 
 
304
  generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
305
 
306
  try:
307
+ if self.transformer.config.in_channels == 12:
308
+ if generation_method == "Long Video Generation":
309
+ init_frames = 0
310
+ last_frames = init_frames + partial_video_length
311
+ while init_frames < length_slider:
312
+ if last_frames >= length_slider:
313
+ if self.pipeline.vae.quant_conv.weight.ndim==5:
314
+ mini_batch_encoder = self.pipeline.vae.mini_batch_encoder
315
+ _partial_video_length = length_slider - init_frames
316
+ _partial_video_length = int(_partial_video_length // mini_batch_encoder * mini_batch_encoder)
317
+ else:
318
+ _partial_video_length = length_slider - init_frames
319
+
320
+ if _partial_video_length <= 0:
321
+ break
322
+ else:
323
+ _partial_video_length = partial_video_length
324
+
325
+ if last_frames >= length_slider:
326
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
327
+ else:
328
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, None, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
329
+
330
+ with torch.no_grad():
331
+ sample = self.pipeline(
332
+ prompt_textbox,
333
+ negative_prompt = negative_prompt_textbox,
334
+ num_inference_steps = sample_step_slider,
335
+ guidance_scale = cfg_scale_slider,
336
+ width = width_slider,
337
+ height = height_slider,
338
+ video_length = _partial_video_length,
339
+ generator = generator,
340
+
341
+ video = input_video,
342
+ mask_video = input_video_mask,
343
+ clip_image = clip_image,
344
+ strength = 1,
345
+ ).videos
346
+
347
+ if init_frames != 0:
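+ # Blend the overlapping frames with a linear cross-fade; e.g. with
+ # overlap_video_length = 4 the new chunk is weighted [0.0, 0.25, 0.5, 0.75]
+ # across the overlap.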
348
+ mix_ratio = torch.from_numpy(
349
+ np.array([float(_index) / float(overlap_video_length) for _index in range(overlap_video_length)], np.float32)
350
+ ).unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
351
+
352
+ new_sample[:, :, -overlap_video_length:] = new_sample[:, :, -overlap_video_length:] * (1 - mix_ratio) + \
353
+ sample[:, :, :overlap_video_length] * mix_ratio
354
+ new_sample = torch.cat([new_sample, sample[:, :, overlap_video_length:]], dim = 2)
355
+
356
+ sample = new_sample
357
+ else:
358
+ new_sample = sample
359
+
360
+ if last_frames >= length_slider:
361
+ break
362
+
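+ # Reuse the last overlap_video_length generated frames as the start images of the
+ # next chunk, then advance the frame counters by the non-overlapping length.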
363
+ start_image = [
364
+ Image.fromarray(
365
+ (sample[0, :, _index].transpose(0, 1).transpose(1, 2) * 255).numpy().astype(np.uint8)
366
+ ) for _index in range(-overlap_video_length, 0)
367
+ ]
368
+
369
+ init_frames = init_frames + _partial_video_length - overlap_video_length
370
+ last_frames = init_frames + _partial_video_length
371
+ else:
372
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
373
+
374
+ sample = self.pipeline(
375
+ prompt_textbox,
376
+ negative_prompt = negative_prompt_textbox,
377
+ num_inference_steps = sample_step_slider,
378
+ guidance_scale = cfg_scale_slider,
379
+ width = width_slider,
380
+ height = height_slider,
381
+ video_length = length_slider if not is_image else 1,
382
+ generator = generator,
383
+
384
+ video = input_video,
385
+ mask_video = input_video_mask,
386
+ clip_image = clip_image,
387
+ ).videos
388
+ else:
389
+ sample = self.pipeline(
390
+ prompt_textbox,
391
+ negative_prompt = negative_prompt_textbox,
392
+ num_inference_steps = sample_step_slider,
393
+ guidance_scale = cfg_scale_slider,
394
+ width = width_slider,
395
+ height = height_slider,
396
+ video_length = length_slider if not is_image else 1,
397
+ generator = generator
398
+ ).videos
399
  except Exception as e:
400
  gc.collect()
401
  torch.cuda.empty_cache()
 
405
  if is_api:
406
  return "", f"Error. error information is {str(e)}"
407
  else:
408
+ return gr.update(), gr.update(), f"Error. error information is {str(e)}"
409
+
410
+ gc.collect()
411
+ torch.cuda.empty_cache()
412
+ torch.cuda.ipc_collect()
413
 
414
  # lora part
415
  if self.lora_model_path != "none":
 
451
  if is_api:
452
  return save_sample_path, "Success"
453
  else:
454
+ if gradio_version_is_above_4:
455
+ return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success"
456
+ else:
457
+ return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
458
  else:
459
  save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4")
460
  save_videos_grid(sample, save_sample_path, fps=12 if self.edition == "v1" else 24)
 
462
  if is_api:
463
  return save_sample_path, "Success"
464
  else:
465
+ if gradio_version_is_above_4:
466
+ return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
467
+ else:
468
+ return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
469
 
470
 
471
  def ui():
 
486
  with gr.Column(variant="panel"):
487
  gr.Markdown(
488
  """
489
+ ### 1. EasyAnimate Edition (EasyAnimate版本).
490
  """
491
  )
492
  with gr.Row():
493
  easyanimate_edition_dropdown = gr.Dropdown(
494
+ label="The config of EasyAnimate Edition (EasyAnimate版本配置)",
495
+ choices=["v1", "v2", "v3"],
496
+ value="v3",
497
  interactive=True,
498
  )
499
  gr.Markdown(
500
  """
501
+ ### 2. Model checkpoints (模型路径).
502
  """
503
  )
504
  with gr.Row():
505
  diffusion_transformer_dropdown = gr.Dropdown(
506
+ label="Pretrained Model Path (预训练模型路径)",
507
  choices=controller.diffusion_transformer_list,
508
  value="none",
509
  interactive=True,
 
517
  diffusion_transformer_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
518
  def refresh_diffusion_transformer():
519
  controller.refresh_diffusion_transformer()
520
+ return gr.update(choices=controller.diffusion_transformer_list)
521
  diffusion_transformer_refresh_button.click(fn=refresh_diffusion_transformer, inputs=[], outputs=[diffusion_transformer_dropdown])
522
 
523
  with gr.Row():
524
  motion_module_dropdown = gr.Dropdown(
525
+ label="Select motion module (选择运动模块[非必需])",
526
  choices=controller.motion_module_list,
527
  value="none",
528
  interactive=True,
 
532
  motion_module_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton", visible=False)
533
  def update_motion_module():
534
  controller.refresh_motion_module()
535
+ return gr.update(choices=controller.motion_module_list)
536
  motion_module_refresh_button.click(fn=update_motion_module, inputs=[], outputs=[motion_module_dropdown])
537
 
538
  base_model_dropdown = gr.Dropdown(
539
+ label="Select base Dreambooth model (选择基模型[非必需])",
540
  choices=controller.personalized_model_list,
541
  value="none",
542
  interactive=True,
543
  )
544
 
545
  lora_model_dropdown = gr.Dropdown(
546
+ label="Select LoRA model (选择LoRA模型[非必需])",
547
  choices=["none"] + controller.personalized_model_list,
548
  value="none",
549
  interactive=True,
550
  )
551
 
552
+ lora_alpha_slider = gr.Slider(label="LoRA alpha (LoRA权重)", value=0.55, minimum=0, maximum=2, interactive=True)
553
 
554
  personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
555
  def update_personalized_model():
556
  controller.refresh_personalized_model()
557
  return [
558
+ gr.update(choices=controller.personalized_model_list),
559
+ gr.update(choices=["none"] + controller.personalized_model_list)
560
  ]
561
  personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])
562
 
563
  with gr.Column(variant="panel"):
564
  gr.Markdown(
565
  """
566
+ ### 3. Configs for Generation (生成参数配置).
567
  """
568
  )
569
 
570
+ prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
571
+ negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion." )
572
 
573
  with gr.Row():
574
  with gr.Column():
575
  with gr.Row():
576
+ sampler_dropdown = gr.Dropdown(label="Sampling method (采样器种类)", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
577
+ sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=30, minimum=10, maximum=100, step=1)
578
 
579
+ resize_method = gr.Radio(
580
+ ["Generate by", "Resize to the Start Image"],
581
+ value="Generate by",
582
+ show_label=False,
583
+ )
584
+ width_slider = gr.Slider(label="Width (视频宽度)", value=672, minimum=128, maximum=1280, step=16)
585
+ height_slider = gr.Slider(label="Height (视频高度)", value=384, minimum=128, maximum=1280, step=16)
586
+ base_resolution = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 768, 960], visible=False)
587
+
588
+ with gr.Group():
589
+ generation_method = gr.Radio(
590
+ ["Video Generation", "Image Generation", "Long Video Generation"],
591
+ value="Video Generation",
592
+ show_label=False,
593
+ )
594
+ with gr.Row():
595
+ length_slider = gr.Slider(label="Animation length (视频帧数)", value=144, minimum=8, maximum=144, step=8)
596
+ overlap_video_length = gr.Slider(label="Overlap length (视频续写的重叠帧数)", value=4, minimum=1, maximum=4, step=1, visible=False)
597
+ partial_video_length = gr.Slider(label="Partial video generation length (每个部分的视频生成帧数)", value=72, minimum=8, maximum=144, step=8, visible=False)
598
+
599
+ with gr.Accordion("Image to Video (图片到视频)", open=False):
600
+ start_image = gr.Image(label="The image at the beginning of the video (图片到视频的开始图片)", show_label=True, elem_id="i2v_start", sources="upload", type="filepath")
601
+
602
+ template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
603
+ def select_template(evt: gr.SelectData):
604
+ text = {
605
+ "asset/1.png": "The dog is looking at camera and smiling. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
606
+ "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
607
+ "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
608
+ "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
609
+ "asset/5.png": "fireworks display over night city. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
610
+ }[template_gallery_path[evt.index]]
611
+ return template_gallery_path[evt.index], text
612
+
613
+ template_gallery = gr.Gallery(
614
+ template_gallery_path,
615
+ columns=5, rows=1,
616
+ height=140,
617
+ allow_preview=False,
618
+ container=False,
619
+ label="Template Examples",
620
+ )
621
+ template_gallery.select(select_template, None, [start_image, prompt_textbox])
622
+
623
+ with gr.Accordion("The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", open=False):
624
+ end_image = gr.Image(label="The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", show_label=False, elem_id="i2v_end", sources="upload", type="filepath")
625
+
626
+ cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=7.0, minimum=0, maximum=20)
627
 
628
  with gr.Row():
629
+ seed_textbox = gr.Textbox(label="Seed (随机种子)", value=43)
630
  seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
631
+ seed_button.click(
632
+ fn=lambda: gr.Textbox(value=random.randint(1, 10**8)) if gradio_version_is_above_4 else gr.Textbox.update(value=random.randint(1, 10**8)),
633
+ inputs=[],
634
+ outputs=[seed_textbox]
635
+ )
636
 
637
+ generate_button = gr.Button(value="Generate (生成)", variant='primary')
638
 
639
  with gr.Column():
640
+ result_image = gr.Image(label="Generated Image (生成图片)", interactive=False, visible=False)
641
+ result_video = gr.Video(label="Generated Animation (生成视频)", interactive=False)
642
  infer_progress = gr.Textbox(
643
+ label="Generation Info (生成信息)",
644
  value="No task currently",
645
  interactive=False
646
  )
647
 
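+ # Show or hide the length/overlap sliders depending on the selected mode: normal
+ # video (up to 144 frames), a single image, or long video generation (up to 1440
+ # frames, produced in overlapping chunks).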
648
+ def upload_generation_method(generation_method):
649
+ if generation_method == "Video Generation":
650
+ return [gr.update(visible=True, maximum=144, value=144), gr.update(visible=False), gr.update(visible=False)]
651
+ elif generation_method == "Image Generation":
652
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)]
653
+ else:
654
+ return [gr.update(visible=True, maximum=1440), gr.update(visible=True), gr.update(visible=True)]
655
+ generation_method.change(
656
+ upload_generation_method, generation_method, [length_slider, overlap_video_length, partial_video_length]
657
+ )
658
+
659
+ def upload_resize_method(resize_method):
660
+ if resize_method == "Generate by":
661
+ return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
662
+ else:
663
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
664
+ resize_method.change(
665
+ upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
666
  )
667
+
668
  easyanimate_edition_dropdown.change(
669
  fn=controller.update_edition,
670
  inputs=[easyanimate_edition_dropdown],
 
673
  diffusion_transformer_dropdown,
674
  motion_module_dropdown,
675
  motion_module_refresh_button,
 
676
  width_slider,
677
  height_slider,
678
  length_slider,
 
690
  negative_prompt_textbox,
691
  sampler_dropdown,
692
  sample_step_slider,
693
+ resize_method,
694
  width_slider,
695
  height_slider,
696
+ base_resolution,
697
+ generation_method,
698
  length_slider,
699
+ overlap_video_length,
700
+ partial_video_length,
701
  cfg_scale_slider,
702
+ start_image,
703
+ end_image,
704
  seed_textbox,
705
  ],
706
  outputs=[result_image, result_video, infer_progress]
 
710
 
711
  class EasyAnimateController_Modelscope:
712
  def __init__(self, edition, config_path, model_name, savedir_sample):
713
+ # Weight Dtype
714
+ weight_dtype = torch.bfloat16
715
+
716
+ # Basic dir
717
+ self.basedir = os.getcwd()
718
+ self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
719
+ self.lora_model_path = "none"
720
+ self.savedir_sample = savedir_sample
721
+ self.refresh_personalized_model()
722
  os.makedirs(self.savedir_sample, exist_ok=True)
723
 
724
+ # Config and model path
725
  self.edition = edition
726
  self.inference_config = OmegaConf.load(config_path)
727
  # Get Transformer
 
747
  subfolder="text_encoder",
748
  torch_dtype=weight_dtype
749
  )
750
+ # Get pipeline
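+ # Transformers with 12 input channels are the image-to-video (inpaint) variants and
+ # additionally need a CLIP image encoder/processor; all others use the plain
+ # text-to-video pipeline.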
751
+ if self.transformer.config.in_channels != 12:
752
+ self.pipeline = EasyAnimatePipeline(
753
+ vae=self.vae,
754
+ text_encoder=self.text_encoder,
755
+ tokenizer=self.tokenizer,
756
+ transformer=self.transformer,
757
+ scheduler=scheduler_dict["Euler"](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
758
+ )
759
+ else:
760
+ clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(
761
+ model_name, subfolder="image_encoder"
762
+ ).to("cuda", weight_dtype)
763
+ clip_image_processor = CLIPImageProcessor.from_pretrained(
764
+ model_name, subfolder="image_encoder"
765
+ )
766
+ self.pipeline = EasyAnimateInpaintPipeline(
767
+ vae=self.vae,
768
+ text_encoder=self.text_encoder,
769
+ tokenizer=self.tokenizer,
770
+ transformer=self.transformer,
771
+ scheduler=scheduler_dict["Euler"](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs)),
772
+ clip_image_encoder=clip_image_encoder,
773
+ clip_image_processor=clip_image_processor,
774
+ )
775
+
776
  print("Update diffusion transformer done")
777
 
778
+
779
+ def refresh_personalized_model(self):
780
+ personalized_model_list = sorted(glob(os.path.join(self.personalized_model_dir, "*.safetensors")))
781
+ self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
782
+
783
+
784
+ def update_lora_model(self, lora_model_dropdown):
785
+ print("Update lora model")
786
+ if lora_model_dropdown == "none":
787
+ self.lora_model_path = "none"
788
+ return gr.update()
789
+ lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown)
790
+ self.lora_model_path = lora_model_dropdown
791
+ return gr.update()
792
+
793
+
794
  def generate(
795
  self,
796
+ diffusion_transformer_dropdown,
797
+ motion_module_dropdown,
798
+ base_model_dropdown,
799
+ lora_model_dropdown,
800
+ lora_alpha_slider,
801
  prompt_textbox,
802
  negative_prompt_textbox,
803
  sampler_dropdown,
804
  sample_step_slider,
805
+ resize_method,
806
  width_slider,
807
  height_slider,
808
+ base_resolution,
809
+ generation_method,
810
  length_slider,
811
  cfg_scale_slider,
812
+ start_image,
813
+ end_image,
814
+ seed_textbox,
815
+ is_api = False,
816
  ):
817
+ gc.collect()
818
+ torch.cuda.empty_cache()
819
+ torch.cuda.ipc_collect()
820
+
821
+ if self.transformer is None:
822
+ raise gr.Error(f"Please select a pretrained model path.")
823
+
824
+ if self.lora_model_path != lora_model_dropdown:
825
+ print("Update lora model")
826
+ self.update_lora_model(lora_model_dropdown)
827
+
828
+ if resize_method == "Resize to the Start Image":
829
+ if start_image is None:
830
+ raise gr.Error(f"Please upload an image when using \"Resize to the Start Image\".")
831
+
832
+ aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
833
+ original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size
834
+ closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
835
+ height_slider, width_slider = [int(x / 16) * 16 for x in closest_size]
836
+
837
+ if self.transformer.config.in_channels != 12 and start_image is not None:
838
+ raise gr.Error(f"Please select an image to video pretrained model while using image to video.")
839
+
840
+ if start_image is None and end_image is not None:
841
+ raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.")
842
+
843
+ is_image = True if generation_method == "Image Generation" else False
844
+
845
  if is_xformers_available(): self.transformer.enable_xformers_memory_efficient_attention()
846
 
847
  self.pipeline.scheduler = scheduler_dict[sampler_dropdown](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
848
+ if self.lora_model_path != "none":
849
+ # lora part
850
+ self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
851
  self.pipeline.to("cuda")
852
 
853
  if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
 
855
  generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
856
 
857
  try:
858
+ if self.transformer.config.in_channels == 12:
859
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
860
+
861
+ sample = self.pipeline(
862
+ prompt_textbox,
863
+ negative_prompt = negative_prompt_textbox,
864
+ num_inference_steps = sample_step_slider,
865
+ guidance_scale = cfg_scale_slider,
866
+ width = width_slider,
867
+ height = height_slider,
868
+ video_length = length_slider if not is_image else 1,
869
+ generator = generator,
870
+
871
+ video = input_video,
872
+ mask_video = input_video_mask,
873
+ clip_image = clip_image,
874
+ ).videos
875
+ else:
876
+ sample = self.pipeline(
877
+ prompt_textbox,
878
+ negative_prompt = negative_prompt_textbox,
879
+ num_inference_steps = sample_step_slider,
880
+ guidance_scale = cfg_scale_slider,
881
+ width = width_slider,
882
+ height = height_slider,
883
+ video_length = length_slider if not is_image else 1,
884
+ generator = generator
885
+ ).videos
886
  except Exception as e:
887
  gc.collect()
888
  torch.cuda.empty_cache()
889
  torch.cuda.ipc_collect()
890
+ if self.lora_model_path != "none":
891
+ self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
892
+ if is_api:
893
+ return "", f"Error. error information is {str(e)}"
894
+ else:
895
+ return gr.update(), gr.update(), f"Error. error information is {str(e)}"
896
+
897
+ gc.collect()
898
+ torch.cuda.empty_cache()
899
+ torch.cuda.ipc_collect()
900
+
901
+ # lora part
902
+ if self.lora_model_path != "none":
903
+ self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
904
 
905
  if not os.path.exists(self.savedir_sample):
906
  os.makedirs(self.savedir_sample, exist_ok=True)
 
918
  image = (image * 255).numpy().astype(np.uint8)
919
  image = Image.fromarray(image)
920
  image.save(save_sample_path)
921
+ if is_api:
922
+ return save_sample_path, "Success"
923
+ else:
924
+ if gradio_version_is_above_4:
925
+ return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success"
926
+ else:
927
+ return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
928
  else:
929
  save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4")
930
  save_videos_grid(sample, save_sample_path, fps=12 if self.edition == "v1" else 24)
931
+ if is_api:
932
+ return save_sample_path, "Success"
933
+ else:
934
+ if gradio_version_is_above_4:
935
+ return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
936
+ else:
937
+ return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
938
 
939
 
940
  def ui_modelscope(edition, config_path, model_name, savedir_sample):
 
953
  """
954
  )
955
  with gr.Column(variant="panel"):
956
+ gr.Markdown(
957
+ """
958
+ ### 1. Model checkpoints (模型路径).
959
+ """
960
+ )
961
+ with gr.Row():
962
+ diffusion_transformer_dropdown = gr.Dropdown(
963
+ label="Pretrained Model Path (预训练模型路径)",
964
+ choices=[model_name],
965
+ value=model_name,
966
+ interactive=False,
967
+ )
968
+ with gr.Row():
969
+ motion_module_dropdown = gr.Dropdown(
970
+ label="Select motion module (选择运动模块[非必需])",
971
+ choices=["none"],
972
+ value="none",
973
+ interactive=False,
974
+ visible=False
975
+ )
976
+ base_model_dropdown = gr.Dropdown(
977
+ label="Select base Dreambooth model (选择基模型[非必需])",
978
+ choices=["none"],
979
+ value="none",
980
+ interactive=False,
981
+ visible=False
982
+ )
983
+ with gr.Column(visible=False):
984
+ gr.Markdown(
985
+ """
986
+ ### Minimalism is an example portrait LoRA, triggered by specific prompt words. More details can be found in the [Wiki](https://github.com/aigc-apps/EasyAnimate/wiki/Training-Lora).
987
+ """
988
+ )
989
+ with gr.Row():
990
+ lora_model_dropdown = gr.Dropdown(
991
+ label="Select LoRA model",
992
+ choices=["none", "easyanimatev2_minimalism_lora.safetensors"],
993
+ value="none",
994
+ interactive=True,
995
+ )
996
+
997
+ lora_alpha_slider = gr.Slider(label="LoRA alpha (LoRA权重)", value=0.55, minimum=0, maximum=2, interactive=True)
998
+
999
+ with gr.Column(variant="panel"):
1000
+ gr.Markdown(
1001
+ """
1002
+ ### 2. Configs for Generation (生成参数配置).
1003
+ """
1004
+ )
1005
+
1006
+ prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
1007
+ negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion." )
1008
 
1009
  with gr.Row():
1010
  with gr.Column():
1011
  with gr.Row():
1012
+ sampler_dropdown = gr.Dropdown(label="Sampling method (采样器种类)", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
1013
+ sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=20, minimum=10, maximum=30, step=1, interactive=False)
1014
 
1015
  if edition == "v1":
1016
+ width_slider = gr.Slider(label="Width (视频宽度)", value=512, minimum=384, maximum=704, step=32)
1017
+ height_slider = gr.Slider(label="Height (视频高度)", value=512, minimum=384, maximum=704, step=32)
1018
+
1019
+ with gr.Group():
1020
+ generation_method = gr.Radio(
1021
+ ["Video Generation", "Image Generation"],
1022
+ value="Video Generation",
1023
+ show_label=False,
1024
+ visible=False,
1025
+ )
1026
+ length_slider = gr.Slider(label="Animation length (视频帧数)", value=80, minimum=40, maximum=96, step=1)
1027
+ cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=6.0, minimum=0, maximum=20)
1028
  else:
1029
+ resize_method = gr.Radio(
1030
+ ["Generate by", "Resize to the Start Image"],
1031
+ value="Generate by",
1032
+ show_label=False,
1033
+ )
1034
  with gr.Column():
1035
  gr.Markdown(
1036
  """
1037
+ We support video generation at up to 720p and 144 frames, but some limits apply to this trial: the output resolution is fixed at 384x672 with 48 frames (about 2 s).
1038
+
1039
+ If the start image you upload does not match this resolution, you can use the "Resize to the Start Image" option above.
1040
+
1041
+ If you want to try longer and higher-resolution video generation, please visit our [Github](https://github.com/aigc-apps/EasyAnimate/).
1042
  """
1043
  )
1044
+ width_slider = gr.Slider(label="Width (视频宽度)", value=672, minimum=128, maximum=1280, step=16, interactive=False)
1045
+ height_slider = gr.Slider(label="Height (视频高度)", value=384, minimum=128, maximum=1280, step=16, interactive=False)
1046
+ base_resolution = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 768, 960], interactive=False, visible=False)
1047
+
1048
+ with gr.Group():
1049
+ generation_method = gr.Radio(
1050
+ ["Video Generation", "Image Generation"],
1051
+ value="Video Generation",
1052
+ show_label=False,
1053
+ visible=True,
1054
+ )
1055
+ length_slider = gr.Slider(label="Animation length (视频帧数)", value=48, minimum=8, maximum=48, step=8)
1056
+
1057
+ with gr.Accordion("Image to Video (图片到视频)", open=True):
1058
  with gr.Row():
1059
+ start_image = gr.Image(label="The image at the beginning of the video (图片到视频的开始图片)", show_label=True, elem_id="i2v_start", sources="upload", type="filepath")
1060
+
1061
+ template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
1062
+ def select_template(evt: gr.SelectData):
1063
+ text = {
1064
+ "asset/1.png": "The dog is looking at camera and smiling. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1065
+ "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1066
+ "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1067
+ "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1068
+ "asset/5.png": "fireworks display over night city. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1069
+ }[template_gallery_path[evt.index]]
1070
+ return template_gallery_path[evt.index], text
1071
+
1072
+ template_gallery = gr.Gallery(
1073
+ template_gallery_path,
1074
+ columns=5, rows=1,
1075
+ height=140,
1076
+ allow_preview=False,
1077
+ container=False,
1078
+ label="Template Examples",
1079
+ )
1080
+ template_gallery.select(select_template, None, [start_image, prompt_textbox])
1081
+
1082
+ with gr.Accordion("The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", open=False):
1083
+ end_image = gr.Image(label="The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", show_label=False, elem_id="i2v_end", sources="upload", type="filepath")
1084
+
1085
+
1086
+ cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=7.0, minimum=0, maximum=20)
1087
 
1088
  with gr.Row():
1089
+ seed_textbox = gr.Textbox(label="Seed (随机种子)", value=43)
1090
  seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
1091
+ seed_button.click(
1092
+ fn=lambda: gr.Textbox(value=random.randint(1, 10**8)) if gradio_version_is_above_4 else gr.Textbox.update(value=random.randint(1, 10**8)),
1093
+ inputs=[],
1094
+ outputs=[seed_textbox]
1095
+ )
1096
 
1097
+ generate_button = gr.Button(value="Generate (生成)", variant='primary')
1098
 
1099
  with gr.Column():
1100
+ result_image = gr.Image(label="Generated Image (生成图片)", interactive=False, visible=False)
1101
+ result_video = gr.Video(label="Generated Animation (生成视频)", interactive=False)
1102
  infer_progress = gr.Textbox(
1103
+ label="Generation Info (生成信息)",
1104
  value="No task currently",
1105
  interactive=False
1106
  )
1107
 
1108
+ def upload_generation_method(generation_method):
1109
+ if generation_method == "Video Generation":
1110
+ return gr.update(visible=True, minimum=8, maximum=48, value=48, interactive=True)
1111
+ elif generation_method == "Image Generation":
1112
+ return gr.update(minimum=1, maximum=1, value=1, interactive=False)
1113
+ generation_method.change(
1114
+ upload_generation_method, generation_method, [length_slider]
1115
+ )
1116
+
1117
+ def upload_resize_method(resize_method):
1118
+ if resize_method == "Generate by":
1119
+ return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
1120
+ else:
1121
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
1122
+ resize_method.change(
1123
+ upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
1124
  )
1125
 
1126
  generate_button.click(
1127
  fn=controller.generate,
1128
  inputs=[
1129
+ diffusion_transformer_dropdown,
1130
+ motion_module_dropdown,
1131
+ base_model_dropdown,
1132
+ lora_model_dropdown,
1133
+ lora_alpha_slider,
1134
  prompt_textbox,
1135
  negative_prompt_textbox,
1136
  sampler_dropdown,
1137
  sample_step_slider,
1138
+ resize_method,
1139
  width_slider,
1140
  height_slider,
1141
+ base_resolution,
1142
+ generation_method,
1143
  length_slider,
1144
  cfg_scale_slider,
1145
+ start_image,
1146
+ end_image,
1147
  seed_textbox,
1148
  ],
1149
  outputs=[result_image, result_video, infer_progress]
 
1152
 
1153
 
1154
  def post_eas(
1155
+ diffusion_transformer_dropdown, motion_module_dropdown,
1156
+ base_model_dropdown, lora_model_dropdown, lora_alpha_slider,
1157
  prompt_textbox, negative_prompt_textbox,
1158
+ sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider,
1159
+ base_resolution, generation_method, length_slider, cfg_scale_slider,
1160
+ start_image, end_image, seed_textbox,
1161
  ):
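+ # The start/end images are sent to the EAS endpoint as base64-encoded strings inside
+ # the JSON payload; the remaining controls are forwarded as plain values.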
1162
+ if start_image is not None:
1163
+ with open(start_image, 'rb') as file:
1164
+ file_content = file.read()
1165
+ start_image_encoded_content = base64.b64encode(file_content)
1166
+ start_image = start_image_encoded_content.decode('utf-8')
1167
+
1168
+ if end_image is not None:
1169
+ with open(end_image, 'rb') as file:
1170
+ file_content = file.read()
1171
+ end_image_encoded_content = base64.b64encode(file_content)
1172
+ end_image = end_image_encoded_content.decode('utf-8')
1173
+
1174
  datas = {
1175
+ "base_model_path": base_model_dropdown,
1176
+ "motion_module_path": motion_module_dropdown,
1177
+ "lora_model_path": lora_model_dropdown,
1178
+ "lora_alpha_slider": lora_alpha_slider,
1179
  "prompt_textbox": prompt_textbox,
1180
  "negative_prompt_textbox": negative_prompt_textbox,
1181
  "sampler_dropdown": sampler_dropdown,
1182
  "sample_step_slider": sample_step_slider,
1183
+ "resize_method": resize_method,
1184
  "width_slider": width_slider,
1185
  "height_slider": height_slider,
1186
+ "base_resolution": base_resolution,
1187
+ "generation_method": generation_method,
1188
  "length_slider": length_slider,
1189
  "cfg_scale_slider": cfg_scale_slider,
1190
+ "start_image": start_image,
1191
+ "end_image": end_image,
1192
  "seed_textbox": seed_textbox,
1193
  }
1194
+
1195
  session = requests.session()
1196
  session.headers.update({"Authorization": os.environ.get("EAS_TOKEN")})
1197
 
1198
+ response = session.post(url=f'{os.environ.get("EAS_URL")}/easyanimate/infer_forward', json=datas, timeout=300)
1199
+
1200
  outputs = response.json()
1201
  return outputs
1202
 
 
1208
 
1209
  def generate(
1210
  self,
1211
+ diffusion_transformer_dropdown,
1212
+ motion_module_dropdown,
1213
+ base_model_dropdown,
1214
+ lora_model_dropdown,
1215
+ lora_alpha_slider,
1216
  prompt_textbox,
1217
  negative_prompt_textbox,
1218
  sampler_dropdown,
1219
  sample_step_slider,
1220
+ resize_method,
1221
  width_slider,
1222
  height_slider,
1223
+ base_resolution,
1224
+ generation_method,
1225
  length_slider,
1226
  cfg_scale_slider,
1227
+ start_image,
1228
+ end_image,
1229
  seed_textbox
1230
  ):
1231
+ is_image = True if generation_method == "Image Generation" else False
1232
+
1233
  outputs = post_eas(
1234
+ diffusion_transformer_dropdown, motion_module_dropdown,
1235
+ base_model_dropdown, lora_model_dropdown, lora_alpha_slider,
1236
  prompt_textbox, negative_prompt_textbox,
1237
+ sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider,
1238
+ base_resolution, generation_method, length_slider, cfg_scale_slider,
1239
+ start_image, end_image,
1240
+ seed_textbox
1241
  )
1242
+ try:
1243
+ base64_encoding = outputs["base64_encoding"]
1244
+ except KeyError:
1245
+ return gr.Image(visible=False, value=None), gr.Video(None, visible=True), outputs["message"]
1246
+
1247
  decoded_data = base64.b64decode(base64_encoding)
1248
 
1249
  if not os.path.exists(self.savedir_sample):
 
1285
  """
1286
  )
1287
  with gr.Column(variant="panel"):
1288
+ gr.Markdown(
1289
+ """
1290
+ ### 1. Model checkpoints.
1291
+ """
1292
+ )
1293
+ with gr.Row():
1294
+ diffusion_transformer_dropdown = gr.Dropdown(
1295
+ label="Pretrained Model Path",
1296
+ choices=[model_name],
1297
+ value=model_name,
1298
+ interactive=False,
1299
+ )
1300
+ with gr.Row():
1301
+ motion_module_dropdown = gr.Dropdown(
1302
+ label="Select motion module",
1303
+ choices=["none"],
1304
+ value="none",
1305
+ interactive=False,
1306
+ visible=False
1307
+ )
1308
+ base_model_dropdown = gr.Dropdown(
1309
+ label="Select base Dreambooth model",
1310
+ choices=["none"],
1311
+ value="none",
1312
+ interactive=False,
1313
+ visible=False
1314
+ )
1315
+ with gr.Column(visible=False):
1316
+ gr.Markdown(
1317
+ """
1318
+ ### Minimalism is an example portrait LoRA, triggered by specific prompt words. More details can be found in the [Wiki](https://github.com/aigc-apps/EasyAnimate/wiki/Training-Lora).
1319
+ """
1320
+ )
1321
+ with gr.Row():
1322
+ lora_model_dropdown = gr.Dropdown(
1323
+ label="Select LoRA model",
1324
+ choices=["none", "easyanimatev2_minimalism_lora.safetensors"],
1325
+ value="none",
1326
+ interactive=True,
1327
+ )
1328
+
1329
+ lora_alpha_slider = gr.Slider(label="LoRA alpha (LoRA权重)", value=0.55, minimum=0, maximum=2, interactive=True)
1330
+
1331
+ with gr.Column(variant="panel"):
1332
+ gr.Markdown(
1333
+ """
1334
+ ### 2. Configs for Generation.
1335
+ """
1336
+ )
1337
+
1338
+ prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
1339
  negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion. " )
1340
 
1341
  with gr.Row():
1342
  with gr.Column():
1343
  with gr.Row():
1344
  sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
1345
+ sample_step_slider = gr.Slider(label="Sampling steps", value=20, minimum=10, maximum=30, step=1, interactive=False)
1346
 
1347
  if edition == "v1":
1348
  width_slider = gr.Slider(label="Width", value=512, minimum=384, maximum=704, step=32)
1349
  height_slider = gr.Slider(label="Height", value=512, minimum=384, maximum=704, step=32)
1350
+
1351
+ with gr.Group():
1352
+ generation_method = gr.Radio(
1353
+ ["Video Generation", "Image Generation"],
1354
+ value="Video Generation",
1355
+ show_label=False,
1356
+ visible=False,
1357
+ )
1358
+ length_slider = gr.Slider(label="Animation length", value=80, minimum=40, maximum=96, step=1)
1359
  cfg_scale_slider = gr.Slider(label="CFG Scale", value=6.0, minimum=0, maximum=20)
1360
  else:
1361
+ resize_method = gr.Radio(
1362
+ ["Generate by", "Resize to the Start Image"],
1363
+ value="Generate by",
1364
+ show_label=False,
1365
+ )
1366
  with gr.Column():
1367
  gr.Markdown(
1368
  """
1369
+ We support video generation at up to 720p and 144 frames, but some limits apply to this trial: the output resolution is fixed at 384x672 with 48 frames (about 2 s).
1370
+
1371
+ If the start image you upload does not match this resolution, you can use the "Resize to the Start Image" option above.
1372
+
1373
+ If you want to try longer and higher-resolution video generation, please visit our [Github](https://github.com/aigc-apps/EasyAnimate/).
1374
  """
1375
  )
1376
+ width_slider = gr.Slider(label="Width (视频宽度)", value=672, minimum=128, maximum=1280, step=16, interactive=False)
1377
+ height_slider = gr.Slider(label="Height (视频高度)", value=384, minimum=128, maximum=1280, step=16, interactive=False)
1378
+ base_resolution = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 768, 960], interactive=False, visible=False)
1379
+
1380
+ with gr.Group():
1381
+ generation_method = gr.Radio(
1382
+ ["Video Generation", "Image Generation"],
1383
+ value="Video Generation",
1384
+ show_label=False,
1385
+ visible=True,
1386
+ )
1387
+ length_slider = gr.Slider(label="Animation length (视频帧数)", value=48, minimum=8, maximum=48, step=8)
1388
+
1389
+ with gr.Accordion("Image to Video", open=True):
1390
+ start_image = gr.Image(label="The image at the beginning of the video", show_label=True, elem_id="i2v_start", sources="upload", type="filepath")
1391
+
1392
+ template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
1393
+ def select_template(evt: gr.SelectData):
1394
+ text = {
1395
+ "asset/1.png": "The dog is looking at camera and smiling. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1396
+ "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1397
+ "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1398
+ "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1399
+ "asset/5.png": "fireworks display over night city. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
1400
+ }[template_gallery_path[evt.index]]
1401
+ return template_gallery_path[evt.index], text
1402
+
1403
+ template_gallery = gr.Gallery(
1404
+ template_gallery_path,
1405
+ columns=5, rows=1,
1406
+ height=140,
1407
+ allow_preview=False,
1408
+ container=False,
1409
+ label="Template Examples",
1410
+ )
1411
+ template_gallery.select(select_template, None, [start_image, prompt_textbox])
1412
+
1413
+ with gr.Accordion("The image at the ending of the video (Optional)", open=False):
1414
+ end_image = gr.Image(label="The image at the ending of the video (Optional)", show_label=True, elem_id="i2v_end", sources="upload", type="filepath")
1415
+
1416
  cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.0, minimum=0, maximum=20)
1417
 
1418
  with gr.Row():
 
1435
  interactive=False
1436
  )
1437
 
1438
+ def upload_generation_method(generation_method):
1439
+ if generation_method == "Video Generation":
1440
+ return gr.update(visible=True, minimum=8, maximum=48, value=48, interactive=True)
1441
+ elif generation_method == "Image Generation":
1442
+ return gr.update(minimum=1, maximum=1, value=1, interactive=False)
1443
+ generation_method.change(
1444
+ upload_generation_method, generation_method, [length_slider]
1445
+ )
1446
+
1447
+ def upload_resize_method(resize_method):
1448
+ if resize_method == "Generate by":
1449
+ return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
1450
+ else:
1451
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
1452
+ resize_method.change(
1453
+ upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
1454
  )
1455
 
1456
  generate_button.click(
1457
  fn=controller.generate,
1458
  inputs=[
1459
+ diffusion_transformer_dropdown,
1460
+ motion_module_dropdown,
1461
+ base_model_dropdown,
1462
+ lora_model_dropdown,
1463
+ lora_alpha_slider,
1464
  prompt_textbox,
1465
  negative_prompt_textbox,
1466
  sampler_dropdown,
1467
  sample_step_slider,
1468
+ resize_method,
1469
  width_slider,
1470
  height_slider,
1471
+ base_resolution,
1472
+ generation_method,
1473
  length_slider,
1474
  cfg_scale_slider,
1475
+ start_image,
1476
+ end_image,
1477
  seed_textbox,
1478
  ],
1479
  outputs=[result_image, result_video, infer_progress]
easyanimate/utils/utils.py CHANGED
@@ -8,6 +8,13 @@ import cv2
8
  from einops import rearrange
9
  from PIL import Image
10
 
11
 
12
  def color_transfer(sc, dc):
13
  """
@@ -62,3 +69,103 @@ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, f
62
  if path.endswith("mp4"):
63
  path = path.replace('.mp4', '.gif')
64
  outputs[0].save(path, format='GIF', append_images=outputs, save_all=True, duration=100, loop=0)
8
  from einops import rearrange
9
  from PIL import Image
10
 
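+ # Scale an image's dimensions so that width * height roughly equals
+ # base_resolution ** 2 while keeping the original aspect ratio.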
11
+ def get_width_and_height_from_image_and_base_resolution(image, base_resolution):
12
+ target_pixels = int(base_resolution) * int(base_resolution)
13
+ original_width, original_height = Image.open(image).size
14
+ ratio = (target_pixels / (original_width * original_height)) ** 0.5
15
+ width_slider = round(original_width * ratio)
16
+ height_slider = round(original_height * ratio)
17
+ return height_slider, width_slider
18
 
19
  def color_transfer(sc, dc):
20
  """
 
69
  if path.endswith("mp4"):
70
  path = path.replace('.mp4', '.gif')
71
  outputs[0].save(path, format='GIF', append_images=outputs, save_all=True, duration=100, loop=0)
72
+
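+ # Build the conditioning tensors for image-to-video generation: input_video holds the
+ # known frames scaled to [0, 1], input_video_mask is 255 where frames still have to be
+ # generated and 0 where they are given, and clip_image is the image handed to the CLIP
+ # image encoder (None when no start image is provided), e.g.
+ # get_image_to_video_latent(start_image_path, None, 48, (384, 672)).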
73
+ def get_image_to_video_latent(validation_image_start, validation_image_end, video_length, sample_size):
74
+ if validation_image_start is not None and validation_image_end is not None:
75
+ if type(validation_image_start) is str and os.path.isfile(validation_image_start):
76
+ image_start = clip_image = Image.open(validation_image_start).convert("RGB")
77
+ else:
78
+ image_start = clip_image = validation_image_start
79
+ if type(validation_image_end) is str and os.path.isfile(validation_image_end):
80
+ image_end = Image.open(validation_image_end).convert("RGB")
81
+ else:
82
+ image_end = validation_image_end
83
+
84
+ if type(image_start) is list:
85
+ clip_image = clip_image[0]
86
+ start_video = torch.cat(
87
+ [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start],
88
+ dim=2
89
+ )
90
+ input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1])
91
+ input_video[:, :, :len(image_start)] = start_video
92
+
93
+ input_video_mask = torch.zeros_like(input_video[:, :1])
94
+ input_video_mask[:, :, len(image_start):] = 255
95
+ else:
96
+ input_video = torch.tile(
97
+ torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0),
98
+ [1, 1, video_length, 1, 1]
99
+ )
100
+ input_video_mask = torch.zeros_like(input_video[:, :1])
101
+ input_video_mask[:, :, 1:] = 255
102
+
103
+ if type(image_end) is list:
104
+ image_end = [_image_end.resize(image_start[0].size if type(image_start) is list else image_start.size) for _image_end in image_end]
105
+ end_video = torch.cat(
106
+ [torch.from_numpy(np.array(_image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_end in image_end],
107
+ dim=2
108
+ )
109
+ input_video[:, :, -len(image_end):] = end_video
110
+
111
+ input_video_mask[:, :, -len(image_end):] = 0
112
+ else:
113
+ image_end = image_end.resize(image_start[0].size if type(image_start) is list else image_start.size)
114
+ input_video[:, :, -1:] = torch.from_numpy(np.array(image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0)
115
+ input_video_mask[:, :, -1:] = 0
116
+
117
+ input_video = input_video / 255
118
+
119
+ elif validation_image_start is not None:
120
+ if type(validation_image_start) is str and os.path.isfile(validation_image_start):
121
+ image_start = clip_image = Image.open(validation_image_start).convert("RGB")
122
+ else:
123
+ image_start = clip_image = validation_image_start
124
+
125
+ if type(image_start) is list:
126
+ clip_image = clip_image[0]
127
+ start_video = torch.cat(
128
+ [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start],
129
+ dim=2
130
+ )
131
+ input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1])
132
+ input_video[:, :, :len(image_start)] = start_video
133
+ input_video = input_video / 255
134
+
135
+ input_video_mask = torch.zeros_like(input_video[:, :1])
136
+ input_video_mask[:, :, len(image_start):] = 255
137
+ else:
138
+ input_video = torch.tile(
139
+ torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0),
140
+ [1, 1, video_length, 1, 1]
141
+ ) / 255
142
+ input_video_mask = torch.zeros_like(input_video[:, :1])
143
+ input_video_mask[:, :, 1:] = 255
144
+ else:
145
+ input_video = torch.zeros([1, 3, video_length, sample_size[0], sample_size[1]])
146
+ input_video_mask = torch.ones([1, 1, video_length, sample_size[0], sample_size[1]]) * 255
147
+ clip_image = None
148
+
149
+ return input_video, input_video_mask, clip_image
150
+
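+ # Read every frame of a video with OpenCV and return them as a list of RGB arrays.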
151
+ def video_frames(input_video_path):
152
+ cap = cv2.VideoCapture(input_video_path)
153
+ frames = []
154
+ while True:
155
+ ret, frame = cap.read()
156
+ if not ret:
157
+ break
158
+ frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
159
+ cap.release()
160
+ cv2.destroyAllWindows()
161
+ return frames
162
+
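+ # Video-to-video conditioning: the first video_length frames form the input video and
+ # the mask is 255 everywhere, i.e. every frame may be re-generated.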
163
+ def get_video_to_video_latent(validation_videos, video_length):
164
+ input_video = video_frames(validation_videos)
165
+ input_video = torch.from_numpy(np.array(input_video))[:video_length]
166
+ input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255
167
+
168
+ input_video_mask = torch.zeros_like(input_video[:, :1])
169
+ input_video_mask[:, :, :] = 255
170
+
171
+ return input_video, input_video_mask, None