jhj0517 commited on
Commit
7c3ff16
1 Parent(s): 0a9bdfb

initial commit

Browse files
Files changed (42) hide show
  1. __init__.py +0 -0
  2. dataset/dance_image.py +130 -0
  3. dataset/dance_video.py +150 -0
  4. models/attention.py +443 -0
  5. models/motion_module.py +388 -0
  6. models/mutual_self_attention.py +363 -0
  7. models/pose_guider.py +57 -0
  8. models/resnet.py +252 -0
  9. models/transformer_2d.py +395 -0
  10. models/transformer_3d.py +169 -0
  11. models/unet_2d_blocks.py +1074 -0
  12. models/unet_2d_condition.py +1307 -0
  13. models/unet_3d.py +675 -0
  14. models/unet_3d_blocks.py +871 -0
  15. musepose/__init__.py +0 -0
  16. musepose/dataset/dance_image.py +130 -0
  17. musepose/dataset/dance_video.py +150 -0
  18. musepose/models/attention.py +443 -0
  19. musepose/models/motion_module.py +388 -0
  20. musepose/models/mutual_self_attention.py +363 -0
  21. musepose/models/pose_guider.py +57 -0
  22. musepose/models/resnet.py +252 -0
  23. musepose/models/transformer_2d.py +395 -0
  24. musepose/models/transformer_3d.py +169 -0
  25. musepose/models/unet_2d_blocks.py +1074 -0
  26. musepose/models/unet_2d_condition.py +1307 -0
  27. musepose/models/unet_3d.py +675 -0
  28. musepose/models/unet_3d_blocks.py +871 -0
  29. musepose/pipelines/__init__.py +0 -0
  30. musepose/pipelines/context.py +76 -0
  31. musepose/pipelines/pipeline_pose2img.py +360 -0
  32. musepose/pipelines/pipeline_pose2vid.py +458 -0
  33. musepose/pipelines/pipeline_pose2vid_long.py +571 -0
  34. musepose/pipelines/utils.py +29 -0
  35. musepose/utils/util.py +133 -0
  36. pipelines/__init__.py +0 -0
  37. pipelines/context.py +76 -0
  38. pipelines/pipeline_pose2img.py +360 -0
  39. pipelines/pipeline_pose2vid.py +458 -0
  40. pipelines/pipeline_pose2vid_long.py +571 -0
  41. pipelines/utils.py +29 -0
  42. utils/util.py +133 -0
__init__.py ADDED
File without changes
dataset/dance_image.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+
4
+ import torch
5
+ import torchvision.transforms as transforms
6
+ from decord import VideoReader
7
+ from PIL import Image
8
+ from torch.utils.data import Dataset
9
+ from transformers import CLIPImageProcessor
10
+
11
+
12
+ class HumanDanceDataset(Dataset):
13
+ def __init__(
14
+ self,
15
+ img_size,
16
+ img_scale=(1.0, 1.0),
17
+ img_ratio=(0.9, 1.0),
18
+ drop_ratio=0.1,
19
+ data_meta_paths=["./data/fahsion_meta.json"],
20
+ sample_margin=30,
21
+ ):
22
+ super().__init__()
23
+
24
+ self.img_size = img_size
25
+ self.img_scale = img_scale
26
+ self.img_ratio = img_ratio
27
+ self.sample_margin = sample_margin
28
+
29
+ # -----
30
+ # vid_meta format:
31
+ # [{'video_path': , 'kps_path': , 'other':},
32
+ # {'video_path': , 'kps_path': , 'other':}]
33
+ # -----
34
+ vid_meta = []
35
+ for data_meta_path in data_meta_paths:
36
+ vid_meta.extend(json.load(open(data_meta_path, "r")))
37
+ self.vid_meta = vid_meta
38
+
39
+ self.clip_image_processor = CLIPImageProcessor()
40
+
41
+ self.transform = transforms.Compose(
42
+ [
43
+ # transforms.RandomResizedCrop(
44
+ # self.img_size,
45
+ # scale=self.img_scale,
46
+ # ratio=self.img_ratio,
47
+ # interpolation=transforms.InterpolationMode.BILINEAR,
48
+ # ),
49
+ transforms.Resize(
50
+ self.img_size,
51
+ ),
52
+ transforms.ToTensor(),
53
+ transforms.Normalize([0.5], [0.5]),
54
+ ]
55
+ )
56
+
57
+ self.cond_transform = transforms.Compose(
58
+ [
59
+ # transforms.RandomResizedCrop(
60
+ # self.img_size,
61
+ # scale=self.img_scale,
62
+ # ratio=self.img_ratio,
63
+ # interpolation=transforms.InterpolationMode.BILINEAR,
64
+ # ),
65
+ transforms.Resize(
66
+ self.img_size,
67
+ ),
68
+ transforms.ToTensor(),
69
+ ]
70
+ )
71
+
72
+ self.drop_ratio = drop_ratio
73
+
74
+ def augmentation(self, image, transform, state=None):
75
+ if state is not None:
76
+ torch.set_rng_state(state)
77
+ return transform(image)
78
+
79
+ def __getitem__(self, index):
80
+ video_meta = self.vid_meta[index]
81
+ video_path = video_meta["video_path"]
82
+ kps_path = video_meta["kps_path"]
83
+
84
+ video_reader = VideoReader(video_path)
85
+ kps_reader = VideoReader(kps_path)
86
+
87
+ assert len(video_reader) == len(
88
+ kps_reader
89
+ ), f"{len(video_reader) = } != {len(kps_reader) = } in {video_path}"
90
+
91
+ video_length = len(video_reader)
92
+
93
+ margin = min(self.sample_margin, video_length)
94
+
95
+ ref_img_idx = random.randint(0, video_length - 1)
96
+ if ref_img_idx + margin < video_length:
97
+ tgt_img_idx = random.randint(ref_img_idx + margin, video_length - 1)
98
+ elif ref_img_idx - margin > 0:
99
+ tgt_img_idx = random.randint(0, ref_img_idx - margin)
100
+ else:
101
+ tgt_img_idx = random.randint(0, video_length - 1)
102
+
103
+ ref_img = video_reader[ref_img_idx]
104
+ ref_img_pil = Image.fromarray(ref_img.asnumpy())
105
+ tgt_img = video_reader[tgt_img_idx]
106
+ tgt_img_pil = Image.fromarray(tgt_img.asnumpy())
107
+
108
+ tgt_pose = kps_reader[tgt_img_idx]
109
+ tgt_pose_pil = Image.fromarray(tgt_pose.asnumpy())
110
+
111
+ state = torch.get_rng_state()
112
+ tgt_img = self.augmentation(tgt_img_pil, self.transform, state)
113
+ tgt_pose_img = self.augmentation(tgt_pose_pil, self.cond_transform, state)
114
+ ref_img_vae = self.augmentation(ref_img_pil, self.transform, state)
115
+ clip_image = self.clip_image_processor(
116
+ images=ref_img_pil, return_tensors="pt"
117
+ ).pixel_values[0]
118
+
119
+ sample = dict(
120
+ video_dir=video_path,
121
+ img=tgt_img,
122
+ tgt_pose=tgt_pose_img,
123
+ ref_img=ref_img_vae,
124
+ clip_images=clip_image,
125
+ )
126
+
127
+ return sample
128
+
129
+ def __len__(self):
130
+ return len(self.vid_meta)
dataset/dance_video.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+ from typing import List
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import torch
8
+ import torchvision.transforms as transforms
9
+ from decord import VideoReader
10
+ from PIL import Image
11
+ from torch.utils.data import Dataset
12
+ from transformers import CLIPImageProcessor
13
+
14
+
15
+ class HumanDanceVideoDataset(Dataset):
16
+ def __init__(
17
+ self,
18
+ sample_rate,
19
+ n_sample_frames,
20
+ width,
21
+ height,
22
+ img_scale=(1.0, 1.0),
23
+ img_ratio=(0.9, 1.0),
24
+ drop_ratio=0.1,
25
+ data_meta_paths=["./data/fashion_meta.json"],
26
+ ):
27
+ super().__init__()
28
+ self.sample_rate = sample_rate
29
+ self.n_sample_frames = n_sample_frames
30
+ self.width = width
31
+ self.height = height
32
+ self.img_scale = img_scale
33
+ self.img_ratio = img_ratio
34
+
35
+ vid_meta = []
36
+ for data_meta_path in data_meta_paths:
37
+ vid_meta.extend(json.load(open(data_meta_path, "r")))
38
+ self.vid_meta = vid_meta
39
+
40
+ self.clip_image_processor = CLIPImageProcessor()
41
+
42
+ self.pixel_transform = transforms.Compose(
43
+ [
44
+ # transforms.RandomResizedCrop(
45
+ # (height, width),
46
+ # scale=self.img_scale,
47
+ # ratio=self.img_ratio,
48
+ # interpolation=transforms.InterpolationMode.BILINEAR,
49
+ # ),
50
+ transforms.Resize(
51
+ (height, width),
52
+ ),
53
+ transforms.ToTensor(),
54
+ transforms.Normalize([0.5], [0.5]),
55
+ ]
56
+ )
57
+
58
+ self.cond_transform = transforms.Compose(
59
+ [
60
+ # transforms.RandomResizedCrop(
61
+ # (height, width),
62
+ # scale=self.img_scale,
63
+ # ratio=self.img_ratio,
64
+ # interpolation=transforms.InterpolationMode.BILINEAR,
65
+ # ),
66
+ transforms.Resize(
67
+ (height, width),
68
+ ),
69
+ transforms.ToTensor(),
70
+ ]
71
+ )
72
+
73
+ self.drop_ratio = drop_ratio
74
+
75
+ def augmentation(self, images, transform, state=None):
76
+ if state is not None:
77
+ torch.set_rng_state(state)
78
+ if isinstance(images, List):
79
+ transformed_images = [transform(img) for img in images]
80
+ ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w)
81
+ else:
82
+ ret_tensor = transform(images) # (c, h, w)
83
+ return ret_tensor
84
+
85
+ def __getitem__(self, index):
86
+ video_meta = self.vid_meta[index]
87
+ video_path = video_meta["video_path"]
88
+ kps_path = video_meta["kps_path"]
89
+
90
+ video_reader = VideoReader(video_path)
91
+ kps_reader = VideoReader(kps_path)
92
+
93
+ assert len(video_reader) == len(
94
+ kps_reader
95
+ ), f"{len(video_reader) = } != {len(kps_reader) = } in {video_path}"
96
+
97
+ video_length = len(video_reader)
98
+ video_fps = video_reader.get_avg_fps()
99
+ # print("fps", video_fps)
100
+ if video_fps > 30: # 30-60
101
+ sample_rate = self.sample_rate*2
102
+ else:
103
+ sample_rate = self.sample_rate
104
+
105
+
106
+ clip_length = min(
107
+ video_length, (self.n_sample_frames - 1) * sample_rate + 1
108
+ )
109
+ start_idx = random.randint(0, video_length - clip_length)
110
+ batch_index = np.linspace(
111
+ start_idx, start_idx + clip_length - 1, self.n_sample_frames, dtype=int
112
+ ).tolist()
113
+
114
+ # read frames and kps
115
+ vid_pil_image_list = []
116
+ pose_pil_image_list = []
117
+ for index in batch_index:
118
+ img = video_reader[index]
119
+ vid_pil_image_list.append(Image.fromarray(img.asnumpy()))
120
+ img = kps_reader[index]
121
+ pose_pil_image_list.append(Image.fromarray(img.asnumpy()))
122
+
123
+ ref_img_idx = random.randint(0, video_length - 1)
124
+ ref_img = Image.fromarray(video_reader[ref_img_idx].asnumpy())
125
+
126
+ # transform
127
+ state = torch.get_rng_state()
128
+ pixel_values_vid = self.augmentation(
129
+ vid_pil_image_list, self.pixel_transform, state
130
+ )
131
+ pixel_values_pose = self.augmentation(
132
+ pose_pil_image_list, self.cond_transform, state
133
+ )
134
+ pixel_values_ref_img = self.augmentation(ref_img, self.pixel_transform, state)
135
+ clip_ref_img = self.clip_image_processor(
136
+ images=ref_img, return_tensors="pt"
137
+ ).pixel_values[0]
138
+
139
+ sample = dict(
140
+ video_dir=video_path,
141
+ pixel_values_vid=pixel_values_vid,
142
+ pixel_values_pose=pixel_values_pose,
143
+ pixel_values_ref_img=pixel_values_ref_img,
144
+ clip_ref_img=clip_ref_img,
145
+ )
146
+
147
+ return sample
148
+
149
+ def __len__(self):
150
+ return len(self.vid_meta)
models/attention.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2
+
3
+ from typing import Any, Dict, Optional
4
+
5
+ import torch
6
+ from diffusers.models.attention import AdaLayerNorm, Attention, FeedForward
7
+ from diffusers.models.embeddings import SinusoidalPositionalEmbedding
8
+ from einops import rearrange
9
+ from torch import nn
10
+
11
+
12
+ class BasicTransformerBlock(nn.Module):
13
+ r"""
14
+ A basic Transformer block.
15
+
16
+ Parameters:
17
+ dim (`int`): The number of channels in the input and output.
18
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
19
+ attention_head_dim (`int`): The number of channels in each head.
20
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
21
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
22
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
23
+ num_embeds_ada_norm (:
24
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
25
+ attention_bias (:
26
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
27
+ only_cross_attention (`bool`, *optional*):
28
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
29
+ double_self_attention (`bool`, *optional*):
30
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
31
+ upcast_attention (`bool`, *optional*):
32
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
33
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
34
+ Whether to use learnable elementwise affine parameters for normalization.
35
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
36
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
37
+ final_dropout (`bool` *optional*, defaults to False):
38
+ Whether to apply a final dropout after the last feed-forward layer.
39
+ attention_type (`str`, *optional*, defaults to `"default"`):
40
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
41
+ positional_embeddings (`str`, *optional*, defaults to `None`):
42
+ The type of positional embeddings to apply to.
43
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
44
+ The maximum number of positional embeddings to apply.
45
+ """
46
+
47
+ def __init__(
48
+ self,
49
+ dim: int,
50
+ num_attention_heads: int,
51
+ attention_head_dim: int,
52
+ dropout=0.0,
53
+ cross_attention_dim: Optional[int] = None,
54
+ activation_fn: str = "geglu",
55
+ num_embeds_ada_norm: Optional[int] = None,
56
+ attention_bias: bool = False,
57
+ only_cross_attention: bool = False,
58
+ double_self_attention: bool = False,
59
+ upcast_attention: bool = False,
60
+ norm_elementwise_affine: bool = True,
61
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
62
+ norm_eps: float = 1e-5,
63
+ final_dropout: bool = False,
64
+ attention_type: str = "default",
65
+ positional_embeddings: Optional[str] = None,
66
+ num_positional_embeddings: Optional[int] = None,
67
+ ):
68
+ super().__init__()
69
+ self.only_cross_attention = only_cross_attention
70
+
71
+ self.use_ada_layer_norm_zero = (
72
+ num_embeds_ada_norm is not None
73
+ ) and norm_type == "ada_norm_zero"
74
+ self.use_ada_layer_norm = (
75
+ num_embeds_ada_norm is not None
76
+ ) and norm_type == "ada_norm"
77
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
78
+ self.use_layer_norm = norm_type == "layer_norm"
79
+
80
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
81
+ raise ValueError(
82
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
83
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
84
+ )
85
+
86
+ if positional_embeddings and (num_positional_embeddings is None):
87
+ raise ValueError(
88
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
89
+ )
90
+
91
+ if positional_embeddings == "sinusoidal":
92
+ self.pos_embed = SinusoidalPositionalEmbedding(
93
+ dim, max_seq_length=num_positional_embeddings
94
+ )
95
+ else:
96
+ self.pos_embed = None
97
+
98
+ # Define 3 blocks. Each block has its own normalization layer.
99
+ # 1. Self-Attn
100
+ if self.use_ada_layer_norm:
101
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
102
+ elif self.use_ada_layer_norm_zero:
103
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
104
+ else:
105
+ self.norm1 = nn.LayerNorm(
106
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
107
+ )
108
+
109
+ self.attn1 = Attention(
110
+ query_dim=dim,
111
+ heads=num_attention_heads,
112
+ dim_head=attention_head_dim,
113
+ dropout=dropout,
114
+ bias=attention_bias,
115
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
116
+ upcast_attention=upcast_attention,
117
+ )
118
+
119
+ # 2. Cross-Attn
120
+ if cross_attention_dim is not None or double_self_attention:
121
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
122
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
123
+ # the second cross attention block.
124
+ self.norm2 = (
125
+ AdaLayerNorm(dim, num_embeds_ada_norm)
126
+ if self.use_ada_layer_norm
127
+ else nn.LayerNorm(
128
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
129
+ )
130
+ )
131
+ self.attn2 = Attention(
132
+ query_dim=dim,
133
+ cross_attention_dim=cross_attention_dim
134
+ if not double_self_attention
135
+ else None,
136
+ heads=num_attention_heads,
137
+ dim_head=attention_head_dim,
138
+ dropout=dropout,
139
+ bias=attention_bias,
140
+ upcast_attention=upcast_attention,
141
+ ) # is self-attn if encoder_hidden_states is none
142
+ else:
143
+ self.norm2 = None
144
+ self.attn2 = None
145
+
146
+ # 3. Feed-forward
147
+ if not self.use_ada_layer_norm_single:
148
+ self.norm3 = nn.LayerNorm(
149
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
150
+ )
151
+
152
+ self.ff = FeedForward(
153
+ dim,
154
+ dropout=dropout,
155
+ activation_fn=activation_fn,
156
+ final_dropout=final_dropout,
157
+ )
158
+
159
+ # 4. Fuser
160
+ if attention_type == "gated" or attention_type == "gated-text-image":
161
+ self.fuser = GatedSelfAttentionDense(
162
+ dim, cross_attention_dim, num_attention_heads, attention_head_dim
163
+ )
164
+
165
+ # 5. Scale-shift for PixArt-Alpha.
166
+ if self.use_ada_layer_norm_single:
167
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
168
+
169
+ # let chunk size default to None
170
+ self._chunk_size = None
171
+ self._chunk_dim = 0
172
+
173
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
174
+ # Sets chunk feed-forward
175
+ self._chunk_size = chunk_size
176
+ self._chunk_dim = dim
177
+
178
+ def forward(
179
+ self,
180
+ hidden_states: torch.FloatTensor,
181
+ attention_mask: Optional[torch.FloatTensor] = None,
182
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
183
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
184
+ timestep: Optional[torch.LongTensor] = None,
185
+ cross_attention_kwargs: Dict[str, Any] = None,
186
+ class_labels: Optional[torch.LongTensor] = None,
187
+ ) -> torch.FloatTensor:
188
+ # Notice that normalization is always applied before the real computation in the following blocks.
189
+ # 0. Self-Attention
190
+ batch_size = hidden_states.shape[0]
191
+
192
+ if self.use_ada_layer_norm:
193
+ norm_hidden_states = self.norm1(hidden_states, timestep)
194
+ elif self.use_ada_layer_norm_zero:
195
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
196
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
197
+ )
198
+ elif self.use_layer_norm:
199
+ norm_hidden_states = self.norm1(hidden_states)
200
+ elif self.use_ada_layer_norm_single:
201
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
202
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
203
+ ).chunk(6, dim=1)
204
+ norm_hidden_states = self.norm1(hidden_states)
205
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
206
+ norm_hidden_states = norm_hidden_states.squeeze(1)
207
+ else:
208
+ raise ValueError("Incorrect norm used")
209
+
210
+ if self.pos_embed is not None:
211
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
212
+
213
+ # 1. Retrieve lora scale.
214
+ lora_scale = (
215
+ cross_attention_kwargs.get("scale", 1.0)
216
+ if cross_attention_kwargs is not None
217
+ else 1.0
218
+ )
219
+
220
+ # 2. Prepare GLIGEN inputs
221
+ cross_attention_kwargs = (
222
+ cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
223
+ )
224
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
225
+
226
+ attn_output = self.attn1(
227
+ norm_hidden_states,
228
+ encoder_hidden_states=encoder_hidden_states
229
+ if self.only_cross_attention
230
+ else None,
231
+ attention_mask=attention_mask,
232
+ **cross_attention_kwargs,
233
+ )
234
+ if self.use_ada_layer_norm_zero:
235
+ attn_output = gate_msa.unsqueeze(1) * attn_output
236
+ elif self.use_ada_layer_norm_single:
237
+ attn_output = gate_msa * attn_output
238
+
239
+ hidden_states = attn_output + hidden_states
240
+ if hidden_states.ndim == 4:
241
+ hidden_states = hidden_states.squeeze(1)
242
+
243
+ # 2.5 GLIGEN Control
244
+ if gligen_kwargs is not None:
245
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
246
+
247
+ # 3. Cross-Attention
248
+ if self.attn2 is not None:
249
+ if self.use_ada_layer_norm:
250
+ norm_hidden_states = self.norm2(hidden_states, timestep)
251
+ elif self.use_ada_layer_norm_zero or self.use_layer_norm:
252
+ norm_hidden_states = self.norm2(hidden_states)
253
+ elif self.use_ada_layer_norm_single:
254
+ # For PixArt norm2 isn't applied here:
255
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
256
+ norm_hidden_states = hidden_states
257
+ else:
258
+ raise ValueError("Incorrect norm")
259
+
260
+ if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
261
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
262
+
263
+ attn_output = self.attn2(
264
+ norm_hidden_states,
265
+ encoder_hidden_states=encoder_hidden_states,
266
+ attention_mask=encoder_attention_mask,
267
+ **cross_attention_kwargs,
268
+ )
269
+ hidden_states = attn_output + hidden_states
270
+
271
+ # 4. Feed-forward
272
+ if not self.use_ada_layer_norm_single:
273
+ norm_hidden_states = self.norm3(hidden_states)
274
+
275
+ if self.use_ada_layer_norm_zero:
276
+ norm_hidden_states = (
277
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
278
+ )
279
+
280
+ if self.use_ada_layer_norm_single:
281
+ norm_hidden_states = self.norm2(hidden_states)
282
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
283
+
284
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
285
+
286
+ if self.use_ada_layer_norm_zero:
287
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
288
+ elif self.use_ada_layer_norm_single:
289
+ ff_output = gate_mlp * ff_output
290
+
291
+ hidden_states = ff_output + hidden_states
292
+ if hidden_states.ndim == 4:
293
+ hidden_states = hidden_states.squeeze(1)
294
+
295
+ return hidden_states
296
+
297
+
298
+ class TemporalBasicTransformerBlock(nn.Module):
299
+ def __init__(
300
+ self,
301
+ dim: int,
302
+ num_attention_heads: int,
303
+ attention_head_dim: int,
304
+ dropout=0.0,
305
+ cross_attention_dim: Optional[int] = None,
306
+ activation_fn: str = "geglu",
307
+ num_embeds_ada_norm: Optional[int] = None,
308
+ attention_bias: bool = False,
309
+ only_cross_attention: bool = False,
310
+ upcast_attention: bool = False,
311
+ unet_use_cross_frame_attention=None,
312
+ unet_use_temporal_attention=None,
313
+ ):
314
+ super().__init__()
315
+ self.only_cross_attention = only_cross_attention
316
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
317
+ self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
318
+ self.unet_use_temporal_attention = unet_use_temporal_attention
319
+
320
+ # SC-Attn
321
+ self.attn1 = Attention(
322
+ query_dim=dim,
323
+ heads=num_attention_heads,
324
+ dim_head=attention_head_dim,
325
+ dropout=dropout,
326
+ bias=attention_bias,
327
+ upcast_attention=upcast_attention,
328
+ )
329
+ self.norm1 = (
330
+ AdaLayerNorm(dim, num_embeds_ada_norm)
331
+ if self.use_ada_layer_norm
332
+ else nn.LayerNorm(dim)
333
+ )
334
+
335
+ # Cross-Attn
336
+ if cross_attention_dim is not None:
337
+ self.attn2 = Attention(
338
+ query_dim=dim,
339
+ cross_attention_dim=cross_attention_dim,
340
+ heads=num_attention_heads,
341
+ dim_head=attention_head_dim,
342
+ dropout=dropout,
343
+ bias=attention_bias,
344
+ upcast_attention=upcast_attention,
345
+ )
346
+ else:
347
+ self.attn2 = None
348
+
349
+ if cross_attention_dim is not None:
350
+ self.norm2 = (
351
+ AdaLayerNorm(dim, num_embeds_ada_norm)
352
+ if self.use_ada_layer_norm
353
+ else nn.LayerNorm(dim)
354
+ )
355
+ else:
356
+ self.norm2 = None
357
+
358
+ # Feed-forward
359
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
360
+ self.norm3 = nn.LayerNorm(dim)
361
+ self.use_ada_layer_norm_zero = False
362
+
363
+ # Temp-Attn
364
+ assert unet_use_temporal_attention is not None
365
+ if unet_use_temporal_attention:
366
+ self.attn_temp = Attention(
367
+ query_dim=dim,
368
+ heads=num_attention_heads,
369
+ dim_head=attention_head_dim,
370
+ dropout=dropout,
371
+ bias=attention_bias,
372
+ upcast_attention=upcast_attention,
373
+ )
374
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
375
+ self.norm_temp = (
376
+ AdaLayerNorm(dim, num_embeds_ada_norm)
377
+ if self.use_ada_layer_norm
378
+ else nn.LayerNorm(dim)
379
+ )
380
+
381
+ def forward(
382
+ self,
383
+ hidden_states,
384
+ encoder_hidden_states=None,
385
+ timestep=None,
386
+ attention_mask=None,
387
+ video_length=None,
388
+ ):
389
+ norm_hidden_states = (
390
+ self.norm1(hidden_states, timestep)
391
+ if self.use_ada_layer_norm
392
+ else self.norm1(hidden_states)
393
+ )
394
+
395
+ if self.unet_use_cross_frame_attention:
396
+ hidden_states = (
397
+ self.attn1(
398
+ norm_hidden_states,
399
+ attention_mask=attention_mask,
400
+ video_length=video_length,
401
+ )
402
+ + hidden_states
403
+ )
404
+ else:
405
+ hidden_states = (
406
+ self.attn1(norm_hidden_states, attention_mask=attention_mask)
407
+ + hidden_states
408
+ )
409
+
410
+ if self.attn2 is not None:
411
+ # Cross-Attention
412
+ norm_hidden_states = (
413
+ self.norm2(hidden_states, timestep)
414
+ if self.use_ada_layer_norm
415
+ else self.norm2(hidden_states)
416
+ )
417
+ hidden_states = (
418
+ self.attn2(
419
+ norm_hidden_states,
420
+ encoder_hidden_states=encoder_hidden_states,
421
+ attention_mask=attention_mask,
422
+ )
423
+ + hidden_states
424
+ )
425
+
426
+ # Feed-forward
427
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
428
+
429
+ # Temporal-Attention
430
+ if self.unet_use_temporal_attention:
431
+ d = hidden_states.shape[1]
432
+ hidden_states = rearrange(
433
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
434
+ )
435
+ norm_hidden_states = (
436
+ self.norm_temp(hidden_states, timestep)
437
+ if self.use_ada_layer_norm
438
+ else self.norm_temp(hidden_states)
439
+ )
440
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
441
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
442
+
443
+ return hidden_states
models/motion_module.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapt from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Callable, Optional
5
+
6
+ import torch
7
+ from diffusers.models.attention import FeedForward
8
+ from diffusers.models.attention_processor import Attention, AttnProcessor
9
+ from diffusers.utils import BaseOutput
10
+ from diffusers.utils.import_utils import is_xformers_available
11
+ from einops import rearrange, repeat
12
+ from torch import nn
13
+
14
+
15
+ def zero_module(module):
16
+ # Zero out the parameters of a module and return it.
17
+ for p in module.parameters():
18
+ p.detach().zero_()
19
+ return module
20
+
21
+
22
+ @dataclass
23
+ class TemporalTransformer3DModelOutput(BaseOutput):
24
+ sample: torch.FloatTensor
25
+
26
+
27
+ if is_xformers_available():
28
+ import xformers
29
+ import xformers.ops
30
+ else:
31
+ xformers = None
32
+
33
+
34
+ def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
35
+ if motion_module_type == "Vanilla":
36
+ return VanillaTemporalModule(
37
+ in_channels=in_channels,
38
+ **motion_module_kwargs,
39
+ )
40
+ else:
41
+ raise ValueError
42
+
43
+
44
+ class VanillaTemporalModule(nn.Module):
45
+ def __init__(
46
+ self,
47
+ in_channels,
48
+ num_attention_heads=8,
49
+ num_transformer_block=2,
50
+ attention_block_types=("Temporal_Self", "Temporal_Self"),
51
+ cross_frame_attention_mode=None,
52
+ temporal_position_encoding=False,
53
+ temporal_position_encoding_max_len=24,
54
+ temporal_attention_dim_div=1,
55
+ zero_initialize=True,
56
+ ):
57
+ super().__init__()
58
+
59
+ self.temporal_transformer = TemporalTransformer3DModel(
60
+ in_channels=in_channels,
61
+ num_attention_heads=num_attention_heads,
62
+ attention_head_dim=in_channels
63
+ // num_attention_heads
64
+ // temporal_attention_dim_div,
65
+ num_layers=num_transformer_block,
66
+ attention_block_types=attention_block_types,
67
+ cross_frame_attention_mode=cross_frame_attention_mode,
68
+ temporal_position_encoding=temporal_position_encoding,
69
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
70
+ )
71
+
72
+ if zero_initialize:
73
+ self.temporal_transformer.proj_out = zero_module(
74
+ self.temporal_transformer.proj_out
75
+ )
76
+
77
+ def forward(
78
+ self,
79
+ input_tensor,
80
+ temb,
81
+ encoder_hidden_states,
82
+ attention_mask=None,
83
+ anchor_frame_idx=None,
84
+ ):
85
+ hidden_states = input_tensor
86
+ hidden_states = self.temporal_transformer(
87
+ hidden_states, encoder_hidden_states, attention_mask
88
+ )
89
+
90
+ output = hidden_states
91
+ return output
92
+
93
+
94
+ class TemporalTransformer3DModel(nn.Module):
95
+ def __init__(
96
+ self,
97
+ in_channels,
98
+ num_attention_heads,
99
+ attention_head_dim,
100
+ num_layers,
101
+ attention_block_types=(
102
+ "Temporal_Self",
103
+ "Temporal_Self",
104
+ ),
105
+ dropout=0.0,
106
+ norm_num_groups=32,
107
+ cross_attention_dim=768,
108
+ activation_fn="geglu",
109
+ attention_bias=False,
110
+ upcast_attention=False,
111
+ cross_frame_attention_mode=None,
112
+ temporal_position_encoding=False,
113
+ temporal_position_encoding_max_len=24,
114
+ ):
115
+ super().__init__()
116
+
117
+ inner_dim = num_attention_heads * attention_head_dim
118
+
119
+ self.norm = torch.nn.GroupNorm(
120
+ num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
121
+ )
122
+ self.proj_in = nn.Linear(in_channels, inner_dim)
123
+
124
+ self.transformer_blocks = nn.ModuleList(
125
+ [
126
+ TemporalTransformerBlock(
127
+ dim=inner_dim,
128
+ num_attention_heads=num_attention_heads,
129
+ attention_head_dim=attention_head_dim,
130
+ attention_block_types=attention_block_types,
131
+ dropout=dropout,
132
+ norm_num_groups=norm_num_groups,
133
+ cross_attention_dim=cross_attention_dim,
134
+ activation_fn=activation_fn,
135
+ attention_bias=attention_bias,
136
+ upcast_attention=upcast_attention,
137
+ cross_frame_attention_mode=cross_frame_attention_mode,
138
+ temporal_position_encoding=temporal_position_encoding,
139
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
140
+ )
141
+ for d in range(num_layers)
142
+ ]
143
+ )
144
+ self.proj_out = nn.Linear(inner_dim, in_channels)
145
+
146
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
147
+ assert (
148
+ hidden_states.dim() == 5
149
+ ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
150
+ video_length = hidden_states.shape[2]
151
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
152
+
153
+ batch, channel, height, weight = hidden_states.shape
154
+ residual = hidden_states
155
+
156
+ hidden_states = self.norm(hidden_states)
157
+ inner_dim = hidden_states.shape[1]
158
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
159
+ batch, height * weight, inner_dim
160
+ )
161
+ hidden_states = self.proj_in(hidden_states)
162
+
163
+ # Transformer Blocks
164
+ for block in self.transformer_blocks:
165
+ hidden_states = block(
166
+ hidden_states,
167
+ encoder_hidden_states=encoder_hidden_states,
168
+ video_length=video_length,
169
+ )
170
+
171
+ # output
172
+ hidden_states = self.proj_out(hidden_states)
173
+ hidden_states = (
174
+ hidden_states.reshape(batch, height, weight, inner_dim)
175
+ .permute(0, 3, 1, 2)
176
+ .contiguous()
177
+ )
178
+
179
+ output = hidden_states + residual
180
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
181
+
182
+ return output
183
+
184
+
185
+ class TemporalTransformerBlock(nn.Module):
186
+ def __init__(
187
+ self,
188
+ dim,
189
+ num_attention_heads,
190
+ attention_head_dim,
191
+ attention_block_types=(
192
+ "Temporal_Self",
193
+ "Temporal_Self",
194
+ ),
195
+ dropout=0.0,
196
+ norm_num_groups=32,
197
+ cross_attention_dim=768,
198
+ activation_fn="geglu",
199
+ attention_bias=False,
200
+ upcast_attention=False,
201
+ cross_frame_attention_mode=None,
202
+ temporal_position_encoding=False,
203
+ temporal_position_encoding_max_len=24,
204
+ ):
205
+ super().__init__()
206
+
207
+ attention_blocks = []
208
+ norms = []
209
+
210
+ for block_name in attention_block_types:
211
+ attention_blocks.append(
212
+ VersatileAttention(
213
+ attention_mode=block_name.split("_")[0],
214
+ cross_attention_dim=cross_attention_dim
215
+ if block_name.endswith("_Cross")
216
+ else None,
217
+ query_dim=dim,
218
+ heads=num_attention_heads,
219
+ dim_head=attention_head_dim,
220
+ dropout=dropout,
221
+ bias=attention_bias,
222
+ upcast_attention=upcast_attention,
223
+ cross_frame_attention_mode=cross_frame_attention_mode,
224
+ temporal_position_encoding=temporal_position_encoding,
225
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
226
+ )
227
+ )
228
+ norms.append(nn.LayerNorm(dim))
229
+
230
+ self.attention_blocks = nn.ModuleList(attention_blocks)
231
+ self.norms = nn.ModuleList(norms)
232
+
233
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
234
+ self.ff_norm = nn.LayerNorm(dim)
235
+
236
+ def forward(
237
+ self,
238
+ hidden_states,
239
+ encoder_hidden_states=None,
240
+ attention_mask=None,
241
+ video_length=None,
242
+ ):
243
+ for attention_block, norm in zip(self.attention_blocks, self.norms):
244
+ norm_hidden_states = norm(hidden_states)
245
+ hidden_states = (
246
+ attention_block(
247
+ norm_hidden_states,
248
+ encoder_hidden_states=encoder_hidden_states
249
+ if attention_block.is_cross_attention
250
+ else None,
251
+ video_length=video_length,
252
+ )
253
+ + hidden_states
254
+ )
255
+
256
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
257
+
258
+ output = hidden_states
259
+ return output
260
+
261
+
262
+ class PositionalEncoding(nn.Module):
263
+ def __init__(self, d_model, dropout=0.0, max_len=24):
264
+ super().__init__()
265
+ self.dropout = nn.Dropout(p=dropout)
266
+ position = torch.arange(max_len).unsqueeze(1)
267
+ div_term = torch.exp(
268
+ torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
269
+ )
270
+ pe = torch.zeros(1, max_len, d_model)
271
+ pe[0, :, 0::2] = torch.sin(position * div_term)
272
+ pe[0, :, 1::2] = torch.cos(position * div_term)
273
+ self.register_buffer("pe", pe)
274
+
275
+ def forward(self, x):
276
+ x = x + self.pe[:, : x.size(1)]
277
+ return self.dropout(x)
278
+
279
+
280
+ class VersatileAttention(Attention):
281
+ def __init__(
282
+ self,
283
+ attention_mode=None,
284
+ cross_frame_attention_mode=None,
285
+ temporal_position_encoding=False,
286
+ temporal_position_encoding_max_len=24,
287
+ *args,
288
+ **kwargs,
289
+ ):
290
+ super().__init__(*args, **kwargs)
291
+ assert attention_mode == "Temporal"
292
+
293
+ self.attention_mode = attention_mode
294
+ self.is_cross_attention = kwargs["cross_attention_dim"] is not None
295
+
296
+ self.pos_encoder = (
297
+ PositionalEncoding(
298
+ kwargs["query_dim"],
299
+ dropout=0.0,
300
+ max_len=temporal_position_encoding_max_len,
301
+ )
302
+ if (temporal_position_encoding and attention_mode == "Temporal")
303
+ else None
304
+ )
305
+
306
+ def extra_repr(self):
307
+ return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
308
+
309
+ def set_use_memory_efficient_attention_xformers(
310
+ self,
311
+ use_memory_efficient_attention_xformers: bool,
312
+ attention_op: Optional[Callable] = None,
313
+ ):
314
+ if use_memory_efficient_attention_xformers:
315
+ if not is_xformers_available():
316
+ raise ModuleNotFoundError(
317
+ (
318
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
319
+ " xformers"
320
+ ),
321
+ name="xformers",
322
+ )
323
+ elif not torch.cuda.is_available():
324
+ raise ValueError(
325
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
326
+ " only available for GPU "
327
+ )
328
+ else:
329
+ try:
330
+ # Make sure we can run the memory efficient attention
331
+ _ = xformers.ops.memory_efficient_attention(
332
+ torch.randn((1, 2, 40), device="cuda"),
333
+ torch.randn((1, 2, 40), device="cuda"),
334
+ torch.randn((1, 2, 40), device="cuda"),
335
+ )
336
+ except Exception as e:
337
+ raise e
338
+
339
+ # XFormersAttnProcessor corrupts video generation and work with Pytorch 1.13.
340
+ # Pytorch 2.0.1 AttnProcessor works the same as XFormersAttnProcessor in Pytorch 1.13.
341
+ # You don't need XFormersAttnProcessor here.
342
+ # processor = XFormersAttnProcessor(
343
+ # attention_op=attention_op,
344
+ # )
345
+ processor = AttnProcessor()
346
+ else:
347
+ processor = AttnProcessor()
348
+
349
+ self.set_processor(processor)
350
+
351
+ def forward(
352
+ self,
353
+ hidden_states,
354
+ encoder_hidden_states=None,
355
+ attention_mask=None,
356
+ video_length=None,
357
+ **cross_attention_kwargs,
358
+ ):
359
+ if self.attention_mode == "Temporal":
360
+ d = hidden_states.shape[1] # d means HxW
361
+ hidden_states = rearrange(
362
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
363
+ )
364
+
365
+ if self.pos_encoder is not None:
366
+ hidden_states = self.pos_encoder(hidden_states)
367
+
368
+ encoder_hidden_states = (
369
+ repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
370
+ if encoder_hidden_states is not None
371
+ else encoder_hidden_states
372
+ )
373
+
374
+ else:
375
+ raise NotImplementedError
376
+
377
+ hidden_states = self.processor(
378
+ self,
379
+ hidden_states,
380
+ encoder_hidden_states=encoder_hidden_states,
381
+ attention_mask=attention_mask,
382
+ **cross_attention_kwargs,
383
+ )
384
+
385
+ if self.attention_mode == "Temporal":
386
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
387
+
388
+ return hidden_states
models/mutual_self_attention.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/models/mutual_self_attention.py
2
+ from typing import Any, Dict, Optional
3
+
4
+ import torch
5
+ from einops import rearrange
6
+
7
+ from musepose.models.attention import TemporalBasicTransformerBlock
8
+
9
+ from .attention import BasicTransformerBlock
10
+
11
+
12
+ def torch_dfs(model: torch.nn.Module):
13
+ result = [model]
14
+ for child in model.children():
15
+ result += torch_dfs(child)
16
+ return result
17
+
18
+
19
+ class ReferenceAttentionControl:
20
+ def __init__(
21
+ self,
22
+ unet,
23
+ mode="write",
24
+ do_classifier_free_guidance=False,
25
+ attention_auto_machine_weight=float("inf"),
26
+ gn_auto_machine_weight=1.0,
27
+ style_fidelity=1.0,
28
+ reference_attn=True,
29
+ reference_adain=False,
30
+ fusion_blocks="midup",
31
+ batch_size=1,
32
+ ) -> None:
33
+ # 10. Modify self attention and group norm
34
+ self.unet = unet
35
+ assert mode in ["read", "write"]
36
+ assert fusion_blocks in ["midup", "full"]
37
+ self.reference_attn = reference_attn
38
+ self.reference_adain = reference_adain
39
+ self.fusion_blocks = fusion_blocks
40
+ self.register_reference_hooks(
41
+ mode,
42
+ do_classifier_free_guidance,
43
+ attention_auto_machine_weight,
44
+ gn_auto_machine_weight,
45
+ style_fidelity,
46
+ reference_attn,
47
+ reference_adain,
48
+ fusion_blocks,
49
+ batch_size=batch_size,
50
+ )
51
+
52
+ def register_reference_hooks(
53
+ self,
54
+ mode,
55
+ do_classifier_free_guidance,
56
+ attention_auto_machine_weight,
57
+ gn_auto_machine_weight,
58
+ style_fidelity,
59
+ reference_attn,
60
+ reference_adain,
61
+ dtype=torch.float16,
62
+ batch_size=1,
63
+ num_images_per_prompt=1,
64
+ device=torch.device("cpu"),
65
+ fusion_blocks="midup",
66
+ ):
67
+ MODE = mode
68
+ do_classifier_free_guidance = do_classifier_free_guidance
69
+ attention_auto_machine_weight = attention_auto_machine_weight
70
+ gn_auto_machine_weight = gn_auto_machine_weight
71
+ style_fidelity = style_fidelity
72
+ reference_attn = reference_attn
73
+ reference_adain = reference_adain
74
+ fusion_blocks = fusion_blocks
75
+ num_images_per_prompt = num_images_per_prompt
76
+ dtype = dtype
77
+ if do_classifier_free_guidance:
78
+ uc_mask = (
79
+ torch.Tensor(
80
+ [1] * batch_size * num_images_per_prompt * 16
81
+ + [0] * batch_size * num_images_per_prompt * 16
82
+ )
83
+ .to(device)
84
+ .bool()
85
+ )
86
+ else:
87
+ uc_mask = (
88
+ torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
89
+ .to(device)
90
+ .bool()
91
+ )
92
+
93
+ def hacked_basic_transformer_inner_forward(
94
+ self,
95
+ hidden_states: torch.FloatTensor,
96
+ attention_mask: Optional[torch.FloatTensor] = None,
97
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
98
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
99
+ timestep: Optional[torch.LongTensor] = None,
100
+ cross_attention_kwargs: Dict[str, Any] = None,
101
+ class_labels: Optional[torch.LongTensor] = None,
102
+ video_length=None,
103
+ ):
104
+ if self.use_ada_layer_norm: # False
105
+ norm_hidden_states = self.norm1(hidden_states, timestep)
106
+ elif self.use_ada_layer_norm_zero:
107
+ (
108
+ norm_hidden_states,
109
+ gate_msa,
110
+ shift_mlp,
111
+ scale_mlp,
112
+ gate_mlp,
113
+ ) = self.norm1(
114
+ hidden_states,
115
+ timestep,
116
+ class_labels,
117
+ hidden_dtype=hidden_states.dtype,
118
+ )
119
+ else:
120
+ norm_hidden_states = self.norm1(hidden_states)
121
+
122
+ # 1. Self-Attention
123
+ # self.only_cross_attention = False
124
+ cross_attention_kwargs = (
125
+ cross_attention_kwargs if cross_attention_kwargs is not None else {}
126
+ )
127
+ if self.only_cross_attention:
128
+ attn_output = self.attn1(
129
+ norm_hidden_states,
130
+ encoder_hidden_states=encoder_hidden_states
131
+ if self.only_cross_attention
132
+ else None,
133
+ attention_mask=attention_mask,
134
+ **cross_attention_kwargs,
135
+ )
136
+ else:
137
+ if MODE == "write":
138
+ self.bank.append(norm_hidden_states.clone())
139
+ attn_output = self.attn1(
140
+ norm_hidden_states,
141
+ encoder_hidden_states=encoder_hidden_states
142
+ if self.only_cross_attention
143
+ else None,
144
+ attention_mask=attention_mask,
145
+ **cross_attention_kwargs,
146
+ )
147
+ if MODE == "read":
148
+ bank_fea = [
149
+ rearrange(
150
+ d.unsqueeze(1).repeat(1, video_length, 1, 1),
151
+ "b t l c -> (b t) l c",
152
+ )
153
+ for d in self.bank
154
+ ]
155
+ modify_norm_hidden_states = torch.cat(
156
+ [norm_hidden_states] + bank_fea, dim=1
157
+ )
158
+ hidden_states_uc = (
159
+ self.attn1(
160
+ norm_hidden_states,
161
+ encoder_hidden_states=modify_norm_hidden_states,
162
+ attention_mask=attention_mask,
163
+ )
164
+ + hidden_states
165
+ )
166
+ if do_classifier_free_guidance:
167
+ hidden_states_c = hidden_states_uc.clone()
168
+ _uc_mask = uc_mask.clone()
169
+ if hidden_states.shape[0] != _uc_mask.shape[0]:
170
+ _uc_mask = (
171
+ torch.Tensor(
172
+ [1] * (hidden_states.shape[0] // 2)
173
+ + [0] * (hidden_states.shape[0] // 2)
174
+ )
175
+ .to(device)
176
+ .bool()
177
+ )
178
+ hidden_states_c[_uc_mask] = (
179
+ self.attn1(
180
+ norm_hidden_states[_uc_mask],
181
+ encoder_hidden_states=norm_hidden_states[_uc_mask],
182
+ attention_mask=attention_mask,
183
+ )
184
+ + hidden_states[_uc_mask]
185
+ )
186
+ hidden_states = hidden_states_c.clone()
187
+ else:
188
+ hidden_states = hidden_states_uc
189
+
190
+ # self.bank.clear()
191
+ if self.attn2 is not None:
192
+ # Cross-Attention
193
+ norm_hidden_states = (
194
+ self.norm2(hidden_states, timestep)
195
+ if self.use_ada_layer_norm
196
+ else self.norm2(hidden_states)
197
+ )
198
+ hidden_states = (
199
+ self.attn2(
200
+ norm_hidden_states,
201
+ encoder_hidden_states=encoder_hidden_states,
202
+ attention_mask=attention_mask,
203
+ )
204
+ + hidden_states
205
+ )
206
+
207
+ # Feed-forward
208
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
209
+
210
+ # Temporal-Attention
211
+ if self.unet_use_temporal_attention:
212
+ d = hidden_states.shape[1]
213
+ hidden_states = rearrange(
214
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
215
+ )
216
+ norm_hidden_states = (
217
+ self.norm_temp(hidden_states, timestep)
218
+ if self.use_ada_layer_norm
219
+ else self.norm_temp(hidden_states)
220
+ )
221
+ hidden_states = (
222
+ self.attn_temp(norm_hidden_states) + hidden_states
223
+ )
224
+ hidden_states = rearrange(
225
+ hidden_states, "(b d) f c -> (b f) d c", d=d
226
+ )
227
+
228
+ return hidden_states
229
+
230
+ if self.use_ada_layer_norm_zero:
231
+ attn_output = gate_msa.unsqueeze(1) * attn_output
232
+ hidden_states = attn_output + hidden_states
233
+
234
+ if self.attn2 is not None:
235
+ norm_hidden_states = (
236
+ self.norm2(hidden_states, timestep)
237
+ if self.use_ada_layer_norm
238
+ else self.norm2(hidden_states)
239
+ )
240
+
241
+ # 2. Cross-Attention
242
+ attn_output = self.attn2(
243
+ norm_hidden_states,
244
+ encoder_hidden_states=encoder_hidden_states,
245
+ attention_mask=encoder_attention_mask,
246
+ **cross_attention_kwargs,
247
+ )
248
+ hidden_states = attn_output + hidden_states
249
+
250
+ # 3. Feed-forward
251
+ norm_hidden_states = self.norm3(hidden_states)
252
+
253
+ if self.use_ada_layer_norm_zero:
254
+ norm_hidden_states = (
255
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
256
+ )
257
+
258
+ ff_output = self.ff(norm_hidden_states)
259
+
260
+ if self.use_ada_layer_norm_zero:
261
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
262
+
263
+ hidden_states = ff_output + hidden_states
264
+
265
+ return hidden_states
266
+
267
+ if self.reference_attn:
268
+ if self.fusion_blocks == "midup":
269
+ attn_modules = [
270
+ module
271
+ for module in (
272
+ torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
273
+ )
274
+ if isinstance(module, BasicTransformerBlock)
275
+ or isinstance(module, TemporalBasicTransformerBlock)
276
+ ]
277
+ elif self.fusion_blocks == "full":
278
+ attn_modules = [
279
+ module
280
+ for module in torch_dfs(self.unet)
281
+ if isinstance(module, BasicTransformerBlock)
282
+ or isinstance(module, TemporalBasicTransformerBlock)
283
+ ]
284
+ attn_modules = sorted(
285
+ attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
286
+ )
287
+
288
+ for i, module in enumerate(attn_modules):
289
+ module._original_inner_forward = module.forward
290
+ if isinstance(module, BasicTransformerBlock):
291
+ module.forward = hacked_basic_transformer_inner_forward.__get__(
292
+ module, BasicTransformerBlock
293
+ )
294
+ if isinstance(module, TemporalBasicTransformerBlock):
295
+ module.forward = hacked_basic_transformer_inner_forward.__get__(
296
+ module, TemporalBasicTransformerBlock
297
+ )
298
+
299
+ module.bank = []
300
+ module.attn_weight = float(i) / float(len(attn_modules))
301
+
302
+ def update(self, writer, dtype=torch.float16):
303
+ if self.reference_attn:
304
+ if self.fusion_blocks == "midup":
305
+ reader_attn_modules = [
306
+ module
307
+ for module in (
308
+ torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
309
+ )
310
+ if isinstance(module, TemporalBasicTransformerBlock)
311
+ ]
312
+ writer_attn_modules = [
313
+ module
314
+ for module in (
315
+ torch_dfs(writer.unet.mid_block)
316
+ + torch_dfs(writer.unet.up_blocks)
317
+ )
318
+ if isinstance(module, BasicTransformerBlock)
319
+ ]
320
+ elif self.fusion_blocks == "full":
321
+ reader_attn_modules = [
322
+ module
323
+ for module in torch_dfs(self.unet)
324
+ if isinstance(module, TemporalBasicTransformerBlock)
325
+ ]
326
+ writer_attn_modules = [
327
+ module
328
+ for module in torch_dfs(writer.unet)
329
+ if isinstance(module, BasicTransformerBlock)
330
+ ]
331
+ reader_attn_modules = sorted(
332
+ reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
333
+ )
334
+ writer_attn_modules = sorted(
335
+ writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
336
+ )
337
+ for r, w in zip(reader_attn_modules, writer_attn_modules):
338
+ r.bank = [v.clone().to(dtype) for v in w.bank]
339
+ # w.bank.clear()
340
+
341
+ def clear(self):
342
+ if self.reference_attn:
343
+ if self.fusion_blocks == "midup":
344
+ reader_attn_modules = [
345
+ module
346
+ for module in (
347
+ torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
348
+ )
349
+ if isinstance(module, BasicTransformerBlock)
350
+ or isinstance(module, TemporalBasicTransformerBlock)
351
+ ]
352
+ elif self.fusion_blocks == "full":
353
+ reader_attn_modules = [
354
+ module
355
+ for module in torch_dfs(self.unet)
356
+ if isinstance(module, BasicTransformerBlock)
357
+ or isinstance(module, TemporalBasicTransformerBlock)
358
+ ]
359
+ reader_attn_modules = sorted(
360
+ reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
361
+ )
362
+ for r in reader_attn_modules:
363
+ r.bank.clear()
models/pose_guider.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.nn.init as init
6
+ from diffusers.models.modeling_utils import ModelMixin
7
+
8
+ from musepose.models.motion_module import zero_module
9
+ from musepose.models.resnet import InflatedConv3d
10
+
11
+
12
+ class PoseGuider(ModelMixin):
13
+ def __init__(
14
+ self,
15
+ conditioning_embedding_channels: int,
16
+ conditioning_channels: int = 3,
17
+ block_out_channels: Tuple[int] = (16, 32, 64, 128),
18
+ ):
19
+ super().__init__()
20
+ self.conv_in = InflatedConv3d(
21
+ conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
22
+ )
23
+
24
+ self.blocks = nn.ModuleList([])
25
+
26
+ for i in range(len(block_out_channels) - 1):
27
+ channel_in = block_out_channels[i]
28
+ channel_out = block_out_channels[i + 1]
29
+ self.blocks.append(
30
+ InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)
31
+ )
32
+ self.blocks.append(
33
+ InflatedConv3d(
34
+ channel_in, channel_out, kernel_size=3, padding=1, stride=2
35
+ )
36
+ )
37
+
38
+ self.conv_out = zero_module(
39
+ InflatedConv3d(
40
+ block_out_channels[-1],
41
+ conditioning_embedding_channels,
42
+ kernel_size=3,
43
+ padding=1,
44
+ )
45
+ )
46
+
47
+ def forward(self, conditioning):
48
+ embedding = self.conv_in(conditioning)
49
+ embedding = F.silu(embedding)
50
+
51
+ for block in self.blocks:
52
+ embedding = block(embedding)
53
+ embedding = F.silu(embedding)
54
+
55
+ embedding = self.conv_out(embedding)
56
+
57
+ return embedding
models/resnet.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+
8
+
9
+ class InflatedConv3d(nn.Conv2d):
10
+ def forward(self, x):
11
+ video_length = x.shape[2]
12
+
13
+ x = rearrange(x, "b c f h w -> (b f) c h w")
14
+ x = super().forward(x)
15
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
16
+
17
+ return x
18
+
19
+
20
+ class InflatedGroupNorm(nn.GroupNorm):
21
+ def forward(self, x):
22
+ video_length = x.shape[2]
23
+
24
+ x = rearrange(x, "b c f h w -> (b f) c h w")
25
+ x = super().forward(x)
26
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
27
+
28
+ return x
29
+
30
+
31
+ class Upsample3D(nn.Module):
32
+ def __init__(
33
+ self,
34
+ channels,
35
+ use_conv=False,
36
+ use_conv_transpose=False,
37
+ out_channels=None,
38
+ name="conv",
39
+ ):
40
+ super().__init__()
41
+ self.channels = channels
42
+ self.out_channels = out_channels or channels
43
+ self.use_conv = use_conv
44
+ self.use_conv_transpose = use_conv_transpose
45
+ self.name = name
46
+
47
+ conv = None
48
+ if use_conv_transpose:
49
+ raise NotImplementedError
50
+ elif use_conv:
51
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
52
+
53
+ def forward(self, hidden_states, output_size=None):
54
+ assert hidden_states.shape[1] == self.channels
55
+
56
+ if self.use_conv_transpose:
57
+ raise NotImplementedError
58
+
59
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
60
+ dtype = hidden_states.dtype
61
+ if dtype == torch.bfloat16:
62
+ hidden_states = hidden_states.to(torch.float32)
63
+
64
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
65
+ if hidden_states.shape[0] >= 64:
66
+ hidden_states = hidden_states.contiguous()
67
+
68
+ # if `output_size` is passed we force the interpolation output
69
+ # size and do not make use of `scale_factor=2`
70
+ if output_size is None:
71
+ hidden_states = F.interpolate(
72
+ hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest"
73
+ )
74
+ else:
75
+ hidden_states = F.interpolate(
76
+ hidden_states, size=output_size, mode="nearest"
77
+ )
78
+
79
+ # If the input is bfloat16, we cast back to bfloat16
80
+ if dtype == torch.bfloat16:
81
+ hidden_states = hidden_states.to(dtype)
82
+
83
+ # if self.use_conv:
84
+ # if self.name == "conv":
85
+ # hidden_states = self.conv(hidden_states)
86
+ # else:
87
+ # hidden_states = self.Conv2d_0(hidden_states)
88
+ hidden_states = self.conv(hidden_states)
89
+
90
+ return hidden_states
91
+
92
+
93
+ class Downsample3D(nn.Module):
94
+ def __init__(
95
+ self, channels, use_conv=False, out_channels=None, padding=1, name="conv"
96
+ ):
97
+ super().__init__()
98
+ self.channels = channels
99
+ self.out_channels = out_channels or channels
100
+ self.use_conv = use_conv
101
+ self.padding = padding
102
+ stride = 2
103
+ self.name = name
104
+
105
+ if use_conv:
106
+ self.conv = InflatedConv3d(
107
+ self.channels, self.out_channels, 3, stride=stride, padding=padding
108
+ )
109
+ else:
110
+ raise NotImplementedError
111
+
112
+ def forward(self, hidden_states):
113
+ assert hidden_states.shape[1] == self.channels
114
+ if self.use_conv and self.padding == 0:
115
+ raise NotImplementedError
116
+
117
+ assert hidden_states.shape[1] == self.channels
118
+ hidden_states = self.conv(hidden_states)
119
+
120
+ return hidden_states
121
+
122
+
123
+ class ResnetBlock3D(nn.Module):
124
+ def __init__(
125
+ self,
126
+ *,
127
+ in_channels,
128
+ out_channels=None,
129
+ conv_shortcut=False,
130
+ dropout=0.0,
131
+ temb_channels=512,
132
+ groups=32,
133
+ groups_out=None,
134
+ pre_norm=True,
135
+ eps=1e-6,
136
+ non_linearity="swish",
137
+ time_embedding_norm="default",
138
+ output_scale_factor=1.0,
139
+ use_in_shortcut=None,
140
+ use_inflated_groupnorm=None,
141
+ ):
142
+ super().__init__()
143
+ self.pre_norm = pre_norm
144
+ self.pre_norm = True
145
+ self.in_channels = in_channels
146
+ out_channels = in_channels if out_channels is None else out_channels
147
+ self.out_channels = out_channels
148
+ self.use_conv_shortcut = conv_shortcut
149
+ self.time_embedding_norm = time_embedding_norm
150
+ self.output_scale_factor = output_scale_factor
151
+
152
+ if groups_out is None:
153
+ groups_out = groups
154
+
155
+ assert use_inflated_groupnorm is not None
156
+ if use_inflated_groupnorm:
157
+ self.norm1 = InflatedGroupNorm(
158
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
159
+ )
160
+ else:
161
+ self.norm1 = torch.nn.GroupNorm(
162
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
163
+ )
164
+
165
+ self.conv1 = InflatedConv3d(
166
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
167
+ )
168
+
169
+ if temb_channels is not None:
170
+ if self.time_embedding_norm == "default":
171
+ time_emb_proj_out_channels = out_channels
172
+ elif self.time_embedding_norm == "scale_shift":
173
+ time_emb_proj_out_channels = out_channels * 2
174
+ else:
175
+ raise ValueError(
176
+ f"unknown time_embedding_norm : {self.time_embedding_norm} "
177
+ )
178
+
179
+ self.time_emb_proj = torch.nn.Linear(
180
+ temb_channels, time_emb_proj_out_channels
181
+ )
182
+ else:
183
+ self.time_emb_proj = None
184
+
185
+ if use_inflated_groupnorm:
186
+ self.norm2 = InflatedGroupNorm(
187
+ num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
188
+ )
189
+ else:
190
+ self.norm2 = torch.nn.GroupNorm(
191
+ num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
192
+ )
193
+ self.dropout = torch.nn.Dropout(dropout)
194
+ self.conv2 = InflatedConv3d(
195
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
196
+ )
197
+
198
+ if non_linearity == "swish":
199
+ self.nonlinearity = lambda x: F.silu(x)
200
+ elif non_linearity == "mish":
201
+ self.nonlinearity = Mish()
202
+ elif non_linearity == "silu":
203
+ self.nonlinearity = nn.SiLU()
204
+
205
+ self.use_in_shortcut = (
206
+ self.in_channels != self.out_channels
207
+ if use_in_shortcut is None
208
+ else use_in_shortcut
209
+ )
210
+
211
+ self.conv_shortcut = None
212
+ if self.use_in_shortcut:
213
+ self.conv_shortcut = InflatedConv3d(
214
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
215
+ )
216
+
217
+ def forward(self, input_tensor, temb):
218
+ hidden_states = input_tensor
219
+
220
+ hidden_states = self.norm1(hidden_states)
221
+ hidden_states = self.nonlinearity(hidden_states)
222
+
223
+ hidden_states = self.conv1(hidden_states)
224
+
225
+ if temb is not None:
226
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
227
+
228
+ if temb is not None and self.time_embedding_norm == "default":
229
+ hidden_states = hidden_states + temb
230
+
231
+ hidden_states = self.norm2(hidden_states)
232
+
233
+ if temb is not None and self.time_embedding_norm == "scale_shift":
234
+ scale, shift = torch.chunk(temb, 2, dim=1)
235
+ hidden_states = hidden_states * (1 + scale) + shift
236
+
237
+ hidden_states = self.nonlinearity(hidden_states)
238
+
239
+ hidden_states = self.dropout(hidden_states)
240
+ hidden_states = self.conv2(hidden_states)
241
+
242
+ if self.conv_shortcut is not None:
243
+ input_tensor = self.conv_shortcut(input_tensor)
244
+
245
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
246
+
247
+ return output_tensor
248
+
249
+
250
+ class Mish(torch.nn.Module):
251
+ def forward(self, hidden_states):
252
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
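To make the "inflated" layers above concrete, a small sketch (not part of the commit; the import path assumes the repository root is on sys.path): InflatedConv3d folds the frame axis into the batch, runs a plain 2D convolution with shared weights, and unfolds again, so a (b, c, f, h, w) video tensor is convolved frame by frame.

import torch
from models.resnet import InflatedConv3d

conv = InflatedConv3d(4, 320, kernel_size=3, padding=1)   # constructed exactly like nn.Conv2d
latents = torch.randn(2, 4, 16, 64, 64)                   # (b, c, f, h, w)
out = conv(latents)                                       # 2D conv applied per frame
print(out.shape)                                          # torch.Size([2, 320, 16, 64, 64])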
models/transformer_2d.py ADDED
@@ -0,0 +1,395 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformer_2d.py
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, Optional
4
+
5
+ import torch
6
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
7
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
8
+ from diffusers.models.modeling_utils import ModelMixin
9
+ from diffusers.models.normalization import AdaLayerNormSingle
10
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
11
+ from torch import nn
12
+
13
+ from .attention import BasicTransformerBlock
14
+
15
+
16
+ @dataclass
17
+ class Transformer2DModelOutput(BaseOutput):
18
+ """
19
+ The output of [`Transformer2DModel`].
20
+
21
+ Args:
22
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
23
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
24
+ distributions for the unnoised latent pixels.
25
+ """
26
+
27
+ sample: torch.FloatTensor
28
+ ref_feature: torch.FloatTensor
29
+
30
+
31
+ class Transformer2DModel(ModelMixin, ConfigMixin):
32
+ """
33
+ A 2D Transformer model for image-like data.
34
+
35
+ Parameters:
36
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
37
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
38
+ in_channels (`int`, *optional*):
39
+ The number of channels in the input and output (specify if the input is **continuous**).
40
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
41
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
42
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
43
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
44
+ This is fixed during training since it is used to learn a number of position embeddings.
45
+ num_vector_embeds (`int`, *optional*):
46
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
47
+ Includes the class for the masked latent pixel.
48
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
49
+ num_embeds_ada_norm ( `int`, *optional*):
50
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
51
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
52
+ added to the hidden states.
53
+
54
+ During inference, you can denoise for up to but not more than `num_embeds_ada_norm` steps.
55
+ attention_bias (`bool`, *optional*):
56
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
57
+ """
58
+
59
+ _supports_gradient_checkpointing = True
60
+
61
+ @register_to_config
62
+ def __init__(
63
+ self,
64
+ num_attention_heads: int = 16,
65
+ attention_head_dim: int = 88,
66
+ in_channels: Optional[int] = None,
67
+ out_channels: Optional[int] = None,
68
+ num_layers: int = 1,
69
+ dropout: float = 0.0,
70
+ norm_num_groups: int = 32,
71
+ cross_attention_dim: Optional[int] = None,
72
+ attention_bias: bool = False,
73
+ sample_size: Optional[int] = None,
74
+ num_vector_embeds: Optional[int] = None,
75
+ patch_size: Optional[int] = None,
76
+ activation_fn: str = "geglu",
77
+ num_embeds_ada_norm: Optional[int] = None,
78
+ use_linear_projection: bool = False,
79
+ only_cross_attention: bool = False,
80
+ double_self_attention: bool = False,
81
+ upcast_attention: bool = False,
82
+ norm_type: str = "layer_norm",
83
+ norm_elementwise_affine: bool = True,
84
+ norm_eps: float = 1e-5,
85
+ attention_type: str = "default",
86
+ caption_channels: int = None,
87
+ ):
88
+ super().__init__()
89
+ self.use_linear_projection = use_linear_projection
90
+ self.num_attention_heads = num_attention_heads
91
+ self.attention_head_dim = attention_head_dim
92
+ inner_dim = num_attention_heads * attention_head_dim
93
+
94
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
95
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
96
+
97
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
98
+ # Define whether input is continuous or discrete depending on configuration
99
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
100
+ self.is_input_vectorized = num_vector_embeds is not None
101
+ self.is_input_patches = in_channels is not None and patch_size is not None
102
+
103
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
104
+ deprecation_message = (
105
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
106
+ " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
107
+ " Please make sure to update the config accordingly as leaving `norm_type` might lead to incorrect"
108
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
109
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
110
+ )
111
+ deprecate(
112
+ "norm_type!=num_embeds_ada_norm",
113
+ "1.0.0",
114
+ deprecation_message,
115
+ standard_warn=False,
116
+ )
117
+ norm_type = "ada_norm"
118
+
119
+ if self.is_input_continuous and self.is_input_vectorized:
120
+ raise ValueError(
121
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
122
+ " sure that either `in_channels` or `num_vector_embeds` is None."
123
+ )
124
+ elif self.is_input_vectorized and self.is_input_patches:
125
+ raise ValueError(
126
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
127
+ " sure that either `num_vector_embeds` or `num_patches` is None."
128
+ )
129
+ elif (
130
+ not self.is_input_continuous
131
+ and not self.is_input_vectorized
132
+ and not self.is_input_patches
133
+ ):
134
+ raise ValueError(
135
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
136
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
137
+ )
138
+
139
+ # 2. Define input layers
140
+ self.in_channels = in_channels
141
+
142
+ self.norm = torch.nn.GroupNorm(
143
+ num_groups=norm_num_groups,
144
+ num_channels=in_channels,
145
+ eps=1e-6,
146
+ affine=True,
147
+ )
148
+ if use_linear_projection:
149
+ self.proj_in = linear_cls(in_channels, inner_dim)
150
+ else:
151
+ self.proj_in = conv_cls(
152
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
153
+ )
154
+
155
+ # 3. Define transformers blocks
156
+ self.transformer_blocks = nn.ModuleList(
157
+ [
158
+ BasicTransformerBlock(
159
+ inner_dim,
160
+ num_attention_heads,
161
+ attention_head_dim,
162
+ dropout=dropout,
163
+ cross_attention_dim=cross_attention_dim,
164
+ activation_fn=activation_fn,
165
+ num_embeds_ada_norm=num_embeds_ada_norm,
166
+ attention_bias=attention_bias,
167
+ only_cross_attention=only_cross_attention,
168
+ double_self_attention=double_self_attention,
169
+ upcast_attention=upcast_attention,
170
+ norm_type=norm_type,
171
+ norm_elementwise_affine=norm_elementwise_affine,
172
+ norm_eps=norm_eps,
173
+ attention_type=attention_type,
174
+ )
175
+ for d in range(num_layers)
176
+ ]
177
+ )
178
+
179
+ # 4. Define output layers
180
+ self.out_channels = in_channels if out_channels is None else out_channels
181
+ # TODO: should use out_channels for continuous projections
182
+ if use_linear_projection:
183
+ self.proj_out = linear_cls(inner_dim, in_channels)
184
+ else:
185
+ self.proj_out = conv_cls(
186
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0
187
+ )
188
+
189
+ # 5. PixArt-Alpha blocks.
190
+ self.adaln_single = None
191
+ self.use_additional_conditions = False
192
+ if norm_type == "ada_norm_single":
193
+ self.use_additional_conditions = self.config.sample_size == 128
194
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
195
+ # additional conditions until we find better name
196
+ self.adaln_single = AdaLayerNormSingle(
197
+ inner_dim, use_additional_conditions=self.use_additional_conditions
198
+ )
199
+
200
+ self.caption_projection = None
201
+ if caption_channels is not None:
202
+ self.caption_projection = CaptionProjection(
203
+ in_features=caption_channels, hidden_size=inner_dim
204
+ )
205
+
206
+ self.gradient_checkpointing = False
207
+
208
+ def _set_gradient_checkpointing(self, module, value=False):
209
+ if hasattr(module, "gradient_checkpointing"):
210
+ module.gradient_checkpointing = value
211
+
212
+ def forward(
213
+ self,
214
+ hidden_states: torch.Tensor,
215
+ encoder_hidden_states: Optional[torch.Tensor] = None,
216
+ timestep: Optional[torch.LongTensor] = None,
217
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
218
+ class_labels: Optional[torch.LongTensor] = None,
219
+ cross_attention_kwargs: Dict[str, Any] = None,
220
+ attention_mask: Optional[torch.Tensor] = None,
221
+ encoder_attention_mask: Optional[torch.Tensor] = None,
222
+ return_dict: bool = True,
223
+ ):
224
+ """
225
+ The [`Transformer2DModel`] forward method.
226
+
227
+ Args:
228
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
229
+ Input `hidden_states`.
230
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
231
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
232
+ self-attention.
233
+ timestep ( `torch.LongTensor`, *optional*):
234
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
235
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
236
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
237
+ `AdaLayerNormZero`.
238
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
239
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
240
+ `self.processor` in
241
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
242
+ attention_mask ( `torch.Tensor`, *optional*):
243
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
244
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
245
+ negative values to the attention scores corresponding to "discard" tokens.
246
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
247
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
248
+
249
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
250
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
251
+
252
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
253
+ above. This bias will be added to the cross-attention scores.
254
+ return_dict (`bool`, *optional*, defaults to `True`):
255
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
256
+ tuple.
257
+
258
+ Returns:
259
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
260
+ `tuple` where the first element is the sample tensor.
261
+ """
262
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
263
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
264
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
265
+ # expects mask of shape:
266
+ # [batch, key_tokens]
267
+ # adds singleton query_tokens dimension:
268
+ # [batch, 1, key_tokens]
269
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
270
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
271
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
272
+ if attention_mask is not None and attention_mask.ndim == 2:
273
+ # assume that mask is expressed as:
274
+ # (1 = keep, 0 = discard)
275
+ # convert mask into a bias that can be added to attention scores:
276
+ # (keep = +0, discard = -10000.0)
277
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
278
+ attention_mask = attention_mask.unsqueeze(1)
279
+
280
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
281
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
282
+ encoder_attention_mask = (
283
+ 1 - encoder_attention_mask.to(hidden_states.dtype)
284
+ ) * -10000.0
285
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
286
+
287
+ # Retrieve lora scale.
288
+ lora_scale = (
289
+ cross_attention_kwargs.get("scale", 1.0)
290
+ if cross_attention_kwargs is not None
291
+ else 1.0
292
+ )
293
+
294
+ # 1. Input
295
+ batch, _, height, width = hidden_states.shape
296
+ residual = hidden_states
297
+
298
+ hidden_states = self.norm(hidden_states)
299
+ if not self.use_linear_projection:
300
+ hidden_states = (
301
+ self.proj_in(hidden_states, scale=lora_scale)
302
+ if not USE_PEFT_BACKEND
303
+ else self.proj_in(hidden_states)
304
+ )
305
+ inner_dim = hidden_states.shape[1]
306
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
307
+ batch, height * width, inner_dim
308
+ )
309
+ else:
310
+ inner_dim = hidden_states.shape[1]
311
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
312
+ batch, height * width, inner_dim
313
+ )
314
+ hidden_states = (
315
+ self.proj_in(hidden_states, scale=lora_scale)
316
+ if not USE_PEFT_BACKEND
317
+ else self.proj_in(hidden_states)
318
+ )
319
+
320
+ # 2. Blocks
321
+ if self.caption_projection is not None:
322
+ batch_size = hidden_states.shape[0]
323
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
324
+ encoder_hidden_states = encoder_hidden_states.view(
325
+ batch_size, -1, hidden_states.shape[-1]
326
+ )
327
+
328
+ ref_feature = hidden_states.reshape(batch, height, width, inner_dim)
329
+ for block in self.transformer_blocks:
330
+ if self.training and self.gradient_checkpointing:
331
+
332
+ def create_custom_forward(module, return_dict=None):
333
+ def custom_forward(*inputs):
334
+ if return_dict is not None:
335
+ return module(*inputs, return_dict=return_dict)
336
+ else:
337
+ return module(*inputs)
338
+
339
+ return custom_forward
340
+
341
+ ckpt_kwargs: Dict[str, Any] = (
342
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
343
+ )
344
+ hidden_states = torch.utils.checkpoint.checkpoint(
345
+ create_custom_forward(block),
346
+ hidden_states,
347
+ attention_mask,
348
+ encoder_hidden_states,
349
+ encoder_attention_mask,
350
+ timestep,
351
+ cross_attention_kwargs,
352
+ class_labels,
353
+ **ckpt_kwargs,
354
+ )
355
+ else:
356
+ hidden_states = block(
357
+ hidden_states,
358
+ attention_mask=attention_mask,
359
+ encoder_hidden_states=encoder_hidden_states,
360
+ encoder_attention_mask=encoder_attention_mask,
361
+ timestep=timestep,
362
+ cross_attention_kwargs=cross_attention_kwargs,
363
+ class_labels=class_labels,
364
+ )
365
+
366
+ # 3. Output
367
+ if self.is_input_continuous:
368
+ if not self.use_linear_projection:
369
+ hidden_states = (
370
+ hidden_states.reshape(batch, height, width, inner_dim)
371
+ .permute(0, 3, 1, 2)
372
+ .contiguous()
373
+ )
374
+ hidden_states = (
375
+ self.proj_out(hidden_states, scale=lora_scale)
376
+ if not USE_PEFT_BACKEND
377
+ else self.proj_out(hidden_states)
378
+ )
379
+ else:
380
+ hidden_states = (
381
+ self.proj_out(hidden_states, scale=lora_scale)
382
+ if not USE_PEFT_BACKEND
383
+ else self.proj_out(hidden_states)
384
+ )
385
+ hidden_states = (
386
+ hidden_states.reshape(batch, height, width, inner_dim)
387
+ .permute(0, 3, 1, 2)
388
+ .contiguous()
389
+ )
390
+
391
+ output = hidden_states + residual
392
+ if not return_dict:
393
+ return (output, ref_feature)
394
+
395
+ return Transformer2DModelOutput(sample=output, ref_feature=ref_feature)
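The main departure from the stock diffusers Transformer2DModel is the extra ref_feature output: the hidden states captured just before the transformer blocks, reshaped to (batch, height, width, inner_dim), which the reference/mutual self-attention code can cache. A hedged usage sketch follows; the configuration values are illustrative assumptions chosen so that in_channels equals inner_dim, as in SD-style UNets.

import torch
from models.transformer_2d import Transformer2DModel  # assumed import path

model = Transformer2DModel(
    num_attention_heads=8,
    attention_head_dim=40,        # inner_dim = 8 * 40 = 320
    in_channels=320,
    cross_attention_dim=768,      # width of the conditioning tokens
)
x = torch.randn(1, 320, 64, 64)        # spatial latent features
context = torch.randn(1, 77, 768)      # conditioning tokens (e.g. CLIP embeddings)
out = model(x, encoder_hidden_states=context)
print(out.sample.shape)       # (1, 320, 64, 64): transformer output plus input residual
print(out.ref_feature.shape)  # (1, 64, 64, 320): pre-block features for reference attention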
models/transformer_3d.py ADDED
@@ -0,0 +1,169 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
6
+ from diffusers.models import ModelMixin
7
+ from diffusers.utils import BaseOutput
8
+ from diffusers.utils.import_utils import is_xformers_available
9
+ from einops import rearrange, repeat
10
+ from torch import nn
11
+
12
+ from .attention import TemporalBasicTransformerBlock
13
+
14
+
15
+ @dataclass
16
+ class Transformer3DModelOutput(BaseOutput):
17
+ sample: torch.FloatTensor
18
+
19
+
20
+ if is_xformers_available():
21
+ import xformers
22
+ import xformers.ops
23
+ else:
24
+ xformers = None
25
+
26
+
27
+ class Transformer3DModel(ModelMixin, ConfigMixin):
28
+ _supports_gradient_checkpointing = True
29
+
30
+ @register_to_config
31
+ def __init__(
32
+ self,
33
+ num_attention_heads: int = 16,
34
+ attention_head_dim: int = 88,
35
+ in_channels: Optional[int] = None,
36
+ num_layers: int = 1,
37
+ dropout: float = 0.0,
38
+ norm_num_groups: int = 32,
39
+ cross_attention_dim: Optional[int] = None,
40
+ attention_bias: bool = False,
41
+ activation_fn: str = "geglu",
42
+ num_embeds_ada_norm: Optional[int] = None,
43
+ use_linear_projection: bool = False,
44
+ only_cross_attention: bool = False,
45
+ upcast_attention: bool = False,
46
+ unet_use_cross_frame_attention=None,
47
+ unet_use_temporal_attention=None,
48
+ ):
49
+ super().__init__()
50
+ self.use_linear_projection = use_linear_projection
51
+ self.num_attention_heads = num_attention_heads
52
+ self.attention_head_dim = attention_head_dim
53
+ inner_dim = num_attention_heads * attention_head_dim
54
+
55
+ # Define input layers
56
+ self.in_channels = in_channels
57
+
58
+ self.norm = torch.nn.GroupNorm(
59
+ num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
60
+ )
61
+ if use_linear_projection:
62
+ self.proj_in = nn.Linear(in_channels, inner_dim)
63
+ else:
64
+ self.proj_in = nn.Conv2d(
65
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
66
+ )
67
+
68
+ # Define transformers blocks
69
+ self.transformer_blocks = nn.ModuleList(
70
+ [
71
+ TemporalBasicTransformerBlock(
72
+ inner_dim,
73
+ num_attention_heads,
74
+ attention_head_dim,
75
+ dropout=dropout,
76
+ cross_attention_dim=cross_attention_dim,
77
+ activation_fn=activation_fn,
78
+ num_embeds_ada_norm=num_embeds_ada_norm,
79
+ attention_bias=attention_bias,
80
+ only_cross_attention=only_cross_attention,
81
+ upcast_attention=upcast_attention,
82
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
83
+ unet_use_temporal_attention=unet_use_temporal_attention,
84
+ )
85
+ for d in range(num_layers)
86
+ ]
87
+ )
88
+
89
+ # 4. Define output layers
90
+ if use_linear_projection:
91
+ self.proj_out = nn.Linear(in_channels, inner_dim)
92
+ else:
93
+ self.proj_out = nn.Conv2d(
94
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0
95
+ )
96
+
97
+ self.gradient_checkpointing = False
98
+
99
+ def _set_gradient_checkpointing(self, module, value=False):
100
+ if hasattr(module, "gradient_checkpointing"):
101
+ module.gradient_checkpointing = value
102
+
103
+ def forward(
104
+ self,
105
+ hidden_states,
106
+ encoder_hidden_states=None,
107
+ timestep=None,
108
+ return_dict: bool = True,
109
+ ):
110
+ # Input
111
+ assert (
112
+ hidden_states.dim() == 5
113
+ ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
114
+ video_length = hidden_states.shape[2]
115
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
116
+ if encoder_hidden_states.shape[0] != hidden_states.shape[0]:
117
+ encoder_hidden_states = repeat(
118
+ encoder_hidden_states, "b n c -> (b f) n c", f=video_length
119
+ )
120
+
121
+ batch, channel, height, width = hidden_states.shape
122
+ residual = hidden_states
123
+
124
+ hidden_states = self.norm(hidden_states)
125
+ if not self.use_linear_projection:
126
+ hidden_states = self.proj_in(hidden_states)
127
+ inner_dim = hidden_states.shape[1]
128
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
129
+ batch, height * width, inner_dim
130
+ )
131
+ else:
132
+ inner_dim = hidden_states.shape[1]
133
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
134
+ batch, height * width, inner_dim
135
+ )
136
+ hidden_states = self.proj_in(hidden_states)
137
+
138
+ # Blocks
139
+ for i, block in enumerate(self.transformer_blocks):
140
+ hidden_states = block(
141
+ hidden_states,
142
+ encoder_hidden_states=encoder_hidden_states,
143
+ timestep=timestep,
144
+ video_length=video_length,
145
+ )
146
+
147
+ # Output
148
+ if not self.use_linear_projection:
149
+ hidden_states = (
150
+ hidden_states.reshape(batch, height, width, inner_dim)
151
+ .permute(0, 3, 1, 2)
152
+ .contiguous()
153
+ )
154
+ hidden_states = self.proj_out(hidden_states)
155
+ else:
156
+ hidden_states = self.proj_out(hidden_states)
157
+ hidden_states = (
158
+ hidden_states.reshape(batch, height, width, inner_dim)
159
+ .permute(0, 3, 1, 2)
160
+ .contiguous()
161
+ )
162
+
163
+ output = hidden_states + residual
164
+
165
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
166
+ if not return_dict:
167
+ return (output,)
168
+
169
+ return Transformer3DModelOutput(sample=output)
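As with the 2D variant, a hedged sketch of the in/out contract (configuration values are illustrative assumptions): the model takes a 5-D (b, c, f, h, w) latent, folds the frame axis into the batch for the spatial blocks, repeats the conditioning tokens once per frame, and restores the frame axis at the end.

import torch
from models.transformer_3d import Transformer3DModel  # assumed import path

model = Transformer3DModel(
    num_attention_heads=8,
    attention_head_dim=40,     # inner_dim = 320
    in_channels=320,
    cross_attention_dim=768,
)
latents = torch.randn(1, 320, 16, 32, 32)   # (b, c, f, h, w), 16 frames
context = torch.randn(1, 77, 768)           # repeated to (b*f, 77, 768) inside forward
out = model(latents, encoder_hidden_states=context)
print(out.sample.shape)                     # torch.Size([1, 320, 16, 32, 32])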
models/unet_2d_blocks.py ADDED
@@ -0,0 +1,1074 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+ from typing import Any, Dict, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from diffusers.models.activations import get_activation
8
+ from diffusers.models.attention_processor import Attention
9
+ from diffusers.models.dual_transformer_2d import DualTransformer2DModel
10
+ from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
11
+ from diffusers.utils import is_torch_version, logging
12
+ from diffusers.utils.torch_utils import apply_freeu
13
+ from torch import nn
14
+
15
+ from .transformer_2d import Transformer2DModel
16
+
17
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
18
+
19
+
20
+ def get_down_block(
21
+ down_block_type: str,
22
+ num_layers: int,
23
+ in_channels: int,
24
+ out_channels: int,
25
+ temb_channels: int,
26
+ add_downsample: bool,
27
+ resnet_eps: float,
28
+ resnet_act_fn: str,
29
+ transformer_layers_per_block: int = 1,
30
+ num_attention_heads: Optional[int] = None,
31
+ resnet_groups: Optional[int] = None,
32
+ cross_attention_dim: Optional[int] = None,
33
+ downsample_padding: Optional[int] = None,
34
+ dual_cross_attention: bool = False,
35
+ use_linear_projection: bool = False,
36
+ only_cross_attention: bool = False,
37
+ upcast_attention: bool = False,
38
+ resnet_time_scale_shift: str = "default",
39
+ attention_type: str = "default",
40
+ resnet_skip_time_act: bool = False,
41
+ resnet_out_scale_factor: float = 1.0,
42
+ cross_attention_norm: Optional[str] = None,
43
+ attention_head_dim: Optional[int] = None,
44
+ downsample_type: Optional[str] = None,
45
+ dropout: float = 0.0,
46
+ ):
47
+ # If attn head dim is not defined, we default it to the number of heads
48
+ if attention_head_dim is None:
49
+ logger.warning(
50
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
51
+ )
52
+ attention_head_dim = num_attention_heads
53
+
54
+ down_block_type = (
55
+ down_block_type[7:]
56
+ if down_block_type.startswith("UNetRes")
57
+ else down_block_type
58
+ )
59
+ if down_block_type == "DownBlock2D":
60
+ return DownBlock2D(
61
+ num_layers=num_layers,
62
+ in_channels=in_channels,
63
+ out_channels=out_channels,
64
+ temb_channels=temb_channels,
65
+ dropout=dropout,
66
+ add_downsample=add_downsample,
67
+ resnet_eps=resnet_eps,
68
+ resnet_act_fn=resnet_act_fn,
69
+ resnet_groups=resnet_groups,
70
+ downsample_padding=downsample_padding,
71
+ resnet_time_scale_shift=resnet_time_scale_shift,
72
+ )
73
+ elif down_block_type == "CrossAttnDownBlock2D":
74
+ if cross_attention_dim is None:
75
+ raise ValueError(
76
+ "cross_attention_dim must be specified for CrossAttnDownBlock2D"
77
+ )
78
+ return CrossAttnDownBlock2D(
79
+ num_layers=num_layers,
80
+ transformer_layers_per_block=transformer_layers_per_block,
81
+ in_channels=in_channels,
82
+ out_channels=out_channels,
83
+ temb_channels=temb_channels,
84
+ dropout=dropout,
85
+ add_downsample=add_downsample,
86
+ resnet_eps=resnet_eps,
87
+ resnet_act_fn=resnet_act_fn,
88
+ resnet_groups=resnet_groups,
89
+ downsample_padding=downsample_padding,
90
+ cross_attention_dim=cross_attention_dim,
91
+ num_attention_heads=num_attention_heads,
92
+ dual_cross_attention=dual_cross_attention,
93
+ use_linear_projection=use_linear_projection,
94
+ only_cross_attention=only_cross_attention,
95
+ upcast_attention=upcast_attention,
96
+ resnet_time_scale_shift=resnet_time_scale_shift,
97
+ attention_type=attention_type,
98
+ )
99
+ raise ValueError(f"{down_block_type} does not exist.")
100
+
101
+
102
+ def get_up_block(
103
+ up_block_type: str,
104
+ num_layers: int,
105
+ in_channels: int,
106
+ out_channels: int,
107
+ prev_output_channel: int,
108
+ temb_channels: int,
109
+ add_upsample: bool,
110
+ resnet_eps: float,
111
+ resnet_act_fn: str,
112
+ resolution_idx: Optional[int] = None,
113
+ transformer_layers_per_block: int = 1,
114
+ num_attention_heads: Optional[int] = None,
115
+ resnet_groups: Optional[int] = None,
116
+ cross_attention_dim: Optional[int] = None,
117
+ dual_cross_attention: bool = False,
118
+ use_linear_projection: bool = False,
119
+ only_cross_attention: bool = False,
120
+ upcast_attention: bool = False,
121
+ resnet_time_scale_shift: str = "default",
122
+ attention_type: str = "default",
123
+ resnet_skip_time_act: bool = False,
124
+ resnet_out_scale_factor: float = 1.0,
125
+ cross_attention_norm: Optional[str] = None,
126
+ attention_head_dim: Optional[int] = None,
127
+ upsample_type: Optional[str] = None,
128
+ dropout: float = 0.0,
129
+ ) -> nn.Module:
130
+ # If attn head dim is not defined, we default it to the number of heads
131
+ if attention_head_dim is None:
132
+ logger.warning(
133
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
134
+ )
135
+ attention_head_dim = num_attention_heads
136
+
137
+ up_block_type = (
138
+ up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
139
+ )
140
+ if up_block_type == "UpBlock2D":
141
+ return UpBlock2D(
142
+ num_layers=num_layers,
143
+ in_channels=in_channels,
144
+ out_channels=out_channels,
145
+ prev_output_channel=prev_output_channel,
146
+ temb_channels=temb_channels,
147
+ resolution_idx=resolution_idx,
148
+ dropout=dropout,
149
+ add_upsample=add_upsample,
150
+ resnet_eps=resnet_eps,
151
+ resnet_act_fn=resnet_act_fn,
152
+ resnet_groups=resnet_groups,
153
+ resnet_time_scale_shift=resnet_time_scale_shift,
154
+ )
155
+ elif up_block_type == "CrossAttnUpBlock2D":
156
+ if cross_attention_dim is None:
157
+ raise ValueError(
158
+ "cross_attention_dim must be specified for CrossAttnUpBlock2D"
159
+ )
160
+ return CrossAttnUpBlock2D(
161
+ num_layers=num_layers,
162
+ transformer_layers_per_block=transformer_layers_per_block,
163
+ in_channels=in_channels,
164
+ out_channels=out_channels,
165
+ prev_output_channel=prev_output_channel,
166
+ temb_channels=temb_channels,
167
+ resolution_idx=resolution_idx,
168
+ dropout=dropout,
169
+ add_upsample=add_upsample,
170
+ resnet_eps=resnet_eps,
171
+ resnet_act_fn=resnet_act_fn,
172
+ resnet_groups=resnet_groups,
173
+ cross_attention_dim=cross_attention_dim,
174
+ num_attention_heads=num_attention_heads,
175
+ dual_cross_attention=dual_cross_attention,
176
+ use_linear_projection=use_linear_projection,
177
+ only_cross_attention=only_cross_attention,
178
+ upcast_attention=upcast_attention,
179
+ resnet_time_scale_shift=resnet_time_scale_shift,
180
+ attention_type=attention_type,
181
+ )
182
+
183
+ raise ValueError(f"{up_block_type} does not exist.")
184
+
185
+
186
+ class AutoencoderTinyBlock(nn.Module):
187
+ """
188
+ Tiny Autoencoder block used in [`AutoencoderTiny`]. It is a mini residual module consisting of plain conv + ReLU
189
+ blocks.
190
+
191
+ Args:
192
+ in_channels (`int`): The number of input channels.
193
+ out_channels (`int`): The number of output channels.
194
+ act_fn (`str`):
195
+ The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`.
196
+
197
+ Returns:
198
+ `torch.FloatTensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
199
+ `out_channels`.
200
+ """
201
+
202
+ def __init__(self, in_channels: int, out_channels: int, act_fn: str):
203
+ super().__init__()
204
+ act_fn = get_activation(act_fn)
205
+ self.conv = nn.Sequential(
206
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
207
+ act_fn,
208
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
209
+ act_fn,
210
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
211
+ )
212
+ self.skip = (
213
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
214
+ if in_channels != out_channels
215
+ else nn.Identity()
216
+ )
217
+ self.fuse = nn.ReLU()
218
+
219
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
220
+ return self.fuse(self.conv(x) + self.skip(x))
221
+
222
+
223
+ class UNetMidBlock2D(nn.Module):
224
+ """
225
+ A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
226
+
227
+ Args:
228
+ in_channels (`int`): The number of input channels.
229
+ temb_channels (`int`): The number of temporal embedding channels.
230
+ dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
231
+ num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
232
+ resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
233
+ resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
234
+ The type of normalization to apply to the time embeddings. This can help to improve the performance of the
235
+ model on tasks with long-range temporal dependencies.
236
+ resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
237
+ resnet_groups (`int`, *optional*, defaults to 32):
238
+ The number of groups to use in the group normalization layers of the resnet blocks.
239
+ attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
240
+ resnet_pre_norm (`bool`, *optional*, defaults to `True`):
241
+ Whether to use pre-normalization for the resnet blocks.
242
+ add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
243
+ attention_head_dim (`int`, *optional*, defaults to 1):
244
+ Dimension of a single attention head. The number of attention heads is determined based on this value and
245
+ the number of input channels.
246
+ output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
247
+
248
+ Returns:
249
+ `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
250
+ in_channels, height, width)`.
251
+
252
+ """
253
+
254
+ def __init__(
255
+ self,
256
+ in_channels: int,
257
+ temb_channels: int,
258
+ dropout: float = 0.0,
259
+ num_layers: int = 1,
260
+ resnet_eps: float = 1e-6,
261
+ resnet_time_scale_shift: str = "default", # default, spatial
262
+ resnet_act_fn: str = "swish",
263
+ resnet_groups: int = 32,
264
+ attn_groups: Optional[int] = None,
265
+ resnet_pre_norm: bool = True,
266
+ add_attention: bool = True,
267
+ attention_head_dim: int = 1,
268
+ output_scale_factor: float = 1.0,
269
+ ):
270
+ super().__init__()
271
+ resnet_groups = (
272
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
273
+ )
274
+ self.add_attention = add_attention
275
+
276
+ if attn_groups is None:
277
+ attn_groups = (
278
+ resnet_groups if resnet_time_scale_shift == "default" else None
279
+ )
280
+
281
+ # there is always at least one resnet
282
+ resnets = [
283
+ ResnetBlock2D(
284
+ in_channels=in_channels,
285
+ out_channels=in_channels,
286
+ temb_channels=temb_channels,
287
+ eps=resnet_eps,
288
+ groups=resnet_groups,
289
+ dropout=dropout,
290
+ time_embedding_norm=resnet_time_scale_shift,
291
+ non_linearity=resnet_act_fn,
292
+ output_scale_factor=output_scale_factor,
293
+ pre_norm=resnet_pre_norm,
294
+ )
295
+ ]
296
+ attentions = []
297
+
298
+ if attention_head_dim is None:
299
+ logger.warning(
300
+ f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
301
+ )
302
+ attention_head_dim = in_channels
303
+
304
+ for _ in range(num_layers):
305
+ if self.add_attention:
306
+ attentions.append(
307
+ Attention(
308
+ in_channels,
309
+ heads=in_channels // attention_head_dim,
310
+ dim_head=attention_head_dim,
311
+ rescale_output_factor=output_scale_factor,
312
+ eps=resnet_eps,
313
+ norm_num_groups=attn_groups,
314
+ spatial_norm_dim=temb_channels
315
+ if resnet_time_scale_shift == "spatial"
316
+ else None,
317
+ residual_connection=True,
318
+ bias=True,
319
+ upcast_softmax=True,
320
+ _from_deprecated_attn_block=True,
321
+ )
322
+ )
323
+ else:
324
+ attentions.append(None)
325
+
326
+ resnets.append(
327
+ ResnetBlock2D(
328
+ in_channels=in_channels,
329
+ out_channels=in_channels,
330
+ temb_channels=temb_channels,
331
+ eps=resnet_eps,
332
+ groups=resnet_groups,
333
+ dropout=dropout,
334
+ time_embedding_norm=resnet_time_scale_shift,
335
+ non_linearity=resnet_act_fn,
336
+ output_scale_factor=output_scale_factor,
337
+ pre_norm=resnet_pre_norm,
338
+ )
339
+ )
340
+
341
+ self.attentions = nn.ModuleList(attentions)
342
+ self.resnets = nn.ModuleList(resnets)
343
+
344
+ def forward(
345
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None
346
+ ) -> torch.FloatTensor:
347
+ hidden_states = self.resnets[0](hidden_states, temb)
348
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
349
+ if attn is not None:
350
+ hidden_states = attn(hidden_states, temb=temb)
351
+ hidden_states = resnet(hidden_states, temb)
352
+
353
+ return hidden_states
354
+
355
+
356
+ class UNetMidBlock2DCrossAttn(nn.Module):
357
+ def __init__(
358
+ self,
359
+ in_channels: int,
360
+ temb_channels: int,
361
+ dropout: float = 0.0,
362
+ num_layers: int = 1,
363
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
364
+ resnet_eps: float = 1e-6,
365
+ resnet_time_scale_shift: str = "default",
366
+ resnet_act_fn: str = "swish",
367
+ resnet_groups: int = 32,
368
+ resnet_pre_norm: bool = True,
369
+ num_attention_heads: int = 1,
370
+ output_scale_factor: float = 1.0,
371
+ cross_attention_dim: int = 1280,
372
+ dual_cross_attention: bool = False,
373
+ use_linear_projection: bool = False,
374
+ upcast_attention: bool = False,
375
+ attention_type: str = "default",
376
+ ):
377
+ super().__init__()
378
+
379
+ self.has_cross_attention = True
380
+ self.num_attention_heads = num_attention_heads
381
+ resnet_groups = (
382
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
383
+ )
384
+
385
+ # support for variable transformer layers per block
386
+ if isinstance(transformer_layers_per_block, int):
387
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
388
+
389
+ # there is always at least one resnet
390
+ resnets = [
391
+ ResnetBlock2D(
392
+ in_channels=in_channels,
393
+ out_channels=in_channels,
394
+ temb_channels=temb_channels,
395
+ eps=resnet_eps,
396
+ groups=resnet_groups,
397
+ dropout=dropout,
398
+ time_embedding_norm=resnet_time_scale_shift,
399
+ non_linearity=resnet_act_fn,
400
+ output_scale_factor=output_scale_factor,
401
+ pre_norm=resnet_pre_norm,
402
+ )
403
+ ]
404
+ attentions = []
405
+
406
+ for i in range(num_layers):
407
+ if not dual_cross_attention:
408
+ attentions.append(
409
+ Transformer2DModel(
410
+ num_attention_heads,
411
+ in_channels // num_attention_heads,
412
+ in_channels=in_channels,
413
+ num_layers=transformer_layers_per_block[i],
414
+ cross_attention_dim=cross_attention_dim,
415
+ norm_num_groups=resnet_groups,
416
+ use_linear_projection=use_linear_projection,
417
+ upcast_attention=upcast_attention,
418
+ attention_type=attention_type,
419
+ )
420
+ )
421
+ else:
422
+ attentions.append(
423
+ DualTransformer2DModel(
424
+ num_attention_heads,
425
+ in_channels // num_attention_heads,
426
+ in_channels=in_channels,
427
+ num_layers=1,
428
+ cross_attention_dim=cross_attention_dim,
429
+ norm_num_groups=resnet_groups,
430
+ )
431
+ )
432
+ resnets.append(
433
+ ResnetBlock2D(
434
+ in_channels=in_channels,
435
+ out_channels=in_channels,
436
+ temb_channels=temb_channels,
437
+ eps=resnet_eps,
438
+ groups=resnet_groups,
439
+ dropout=dropout,
440
+ time_embedding_norm=resnet_time_scale_shift,
441
+ non_linearity=resnet_act_fn,
442
+ output_scale_factor=output_scale_factor,
443
+ pre_norm=resnet_pre_norm,
444
+ )
445
+ )
446
+
447
+ self.attentions = nn.ModuleList(attentions)
448
+ self.resnets = nn.ModuleList(resnets)
449
+
450
+ self.gradient_checkpointing = False
451
+
452
+ def forward(
453
+ self,
454
+ hidden_states: torch.FloatTensor,
455
+ temb: Optional[torch.FloatTensor] = None,
456
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
457
+ attention_mask: Optional[torch.FloatTensor] = None,
458
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
459
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
460
+ ) -> torch.FloatTensor:
461
+ lora_scale = (
462
+ cross_attention_kwargs.get("scale", 1.0)
463
+ if cross_attention_kwargs is not None
464
+ else 1.0
465
+ )
466
+ hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
467
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
468
+ if self.training and self.gradient_checkpointing:
469
+
470
+ def create_custom_forward(module, return_dict=None):
471
+ def custom_forward(*inputs):
472
+ if return_dict is not None:
473
+ return module(*inputs, return_dict=return_dict)
474
+ else:
475
+ return module(*inputs)
476
+
477
+ return custom_forward
478
+
479
+ ckpt_kwargs: Dict[str, Any] = (
480
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
481
+ )
482
+ hidden_states, ref_feature = attn(
483
+ hidden_states,
484
+ encoder_hidden_states=encoder_hidden_states,
485
+ cross_attention_kwargs=cross_attention_kwargs,
486
+ attention_mask=attention_mask,
487
+ encoder_attention_mask=encoder_attention_mask,
488
+ return_dict=False,
489
+ )
490
+ hidden_states = torch.utils.checkpoint.checkpoint(
491
+ create_custom_forward(resnet),
492
+ hidden_states,
493
+ temb,
494
+ **ckpt_kwargs,
495
+ )
496
+ else:
497
+ hidden_states, ref_feature = attn(
498
+ hidden_states,
499
+ encoder_hidden_states=encoder_hidden_states,
500
+ cross_attention_kwargs=cross_attention_kwargs,
501
+ attention_mask=attention_mask,
502
+ encoder_attention_mask=encoder_attention_mask,
503
+ return_dict=False,
504
+ )
505
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
506
+
507
+ return hidden_states
508
+
509
+
510
+ class CrossAttnDownBlock2D(nn.Module):
511
+ def __init__(
512
+ self,
513
+ in_channels: int,
514
+ out_channels: int,
515
+ temb_channels: int,
516
+ dropout: float = 0.0,
517
+ num_layers: int = 1,
518
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
519
+ resnet_eps: float = 1e-6,
520
+ resnet_time_scale_shift: str = "default",
521
+ resnet_act_fn: str = "swish",
522
+ resnet_groups: int = 32,
523
+ resnet_pre_norm: bool = True,
524
+ num_attention_heads: int = 1,
525
+ cross_attention_dim: int = 1280,
526
+ output_scale_factor: float = 1.0,
527
+ downsample_padding: int = 1,
528
+ add_downsample: bool = True,
529
+ dual_cross_attention: bool = False,
530
+ use_linear_projection: bool = False,
531
+ only_cross_attention: bool = False,
532
+ upcast_attention: bool = False,
533
+ attention_type: str = "default",
534
+ ):
535
+ super().__init__()
536
+ resnets = []
537
+ attentions = []
538
+
539
+ self.has_cross_attention = True
540
+ self.num_attention_heads = num_attention_heads
541
+ if isinstance(transformer_layers_per_block, int):
542
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
543
+
544
+ for i in range(num_layers):
545
+ in_channels = in_channels if i == 0 else out_channels
546
+ resnets.append(
547
+ ResnetBlock2D(
548
+ in_channels=in_channels,
549
+ out_channels=out_channels,
550
+ temb_channels=temb_channels,
551
+ eps=resnet_eps,
552
+ groups=resnet_groups,
553
+ dropout=dropout,
554
+ time_embedding_norm=resnet_time_scale_shift,
555
+ non_linearity=resnet_act_fn,
556
+ output_scale_factor=output_scale_factor,
557
+ pre_norm=resnet_pre_norm,
558
+ )
559
+ )
560
+ if not dual_cross_attention:
561
+ attentions.append(
562
+ Transformer2DModel(
563
+ num_attention_heads,
564
+ out_channels // num_attention_heads,
565
+ in_channels=out_channels,
566
+ num_layers=transformer_layers_per_block[i],
567
+ cross_attention_dim=cross_attention_dim,
568
+ norm_num_groups=resnet_groups,
569
+ use_linear_projection=use_linear_projection,
570
+ only_cross_attention=only_cross_attention,
571
+ upcast_attention=upcast_attention,
572
+ attention_type=attention_type,
573
+ )
574
+ )
575
+ else:
576
+ attentions.append(
577
+ DualTransformer2DModel(
578
+ num_attention_heads,
579
+ out_channels // num_attention_heads,
580
+ in_channels=out_channels,
581
+ num_layers=1,
582
+ cross_attention_dim=cross_attention_dim,
583
+ norm_num_groups=resnet_groups,
584
+ )
585
+ )
586
+ self.attentions = nn.ModuleList(attentions)
587
+ self.resnets = nn.ModuleList(resnets)
588
+
589
+ if add_downsample:
590
+ self.downsamplers = nn.ModuleList(
591
+ [
592
+ Downsample2D(
593
+ out_channels,
594
+ use_conv=True,
595
+ out_channels=out_channels,
596
+ padding=downsample_padding,
597
+ name="op",
598
+ )
599
+ ]
600
+ )
601
+ else:
602
+ self.downsamplers = None
603
+
604
+ self.gradient_checkpointing = False
605
+
606
+ def forward(
607
+ self,
608
+ hidden_states: torch.FloatTensor,
609
+ temb: Optional[torch.FloatTensor] = None,
610
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
611
+ attention_mask: Optional[torch.FloatTensor] = None,
612
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
613
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
614
+ additional_residuals: Optional[torch.FloatTensor] = None,
615
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
616
+ output_states = ()
617
+
618
+ lora_scale = (
619
+ cross_attention_kwargs.get("scale", 1.0)
620
+ if cross_attention_kwargs is not None
621
+ else 1.0
622
+ )
623
+
624
+ blocks = list(zip(self.resnets, self.attentions))
625
+
626
+ for i, (resnet, attn) in enumerate(blocks):
627
+ if self.training and self.gradient_checkpointing:
628
+
629
+ def create_custom_forward(module, return_dict=None):
630
+ def custom_forward(*inputs):
631
+ if return_dict is not None:
632
+ return module(*inputs, return_dict=return_dict)
633
+ else:
634
+ return module(*inputs)
635
+
636
+ return custom_forward
637
+
638
+ ckpt_kwargs: Dict[str, Any] = (
639
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
640
+ )
641
+ hidden_states = torch.utils.checkpoint.checkpoint(
642
+ create_custom_forward(resnet),
643
+ hidden_states,
644
+ temb,
645
+ **ckpt_kwargs,
646
+ )
647
+ hidden_states, ref_feature = attn(
648
+ hidden_states,
649
+ encoder_hidden_states=encoder_hidden_states,
650
+ cross_attention_kwargs=cross_attention_kwargs,
651
+ attention_mask=attention_mask,
652
+ encoder_attention_mask=encoder_attention_mask,
653
+ return_dict=False,
654
+ )
655
+ else:
656
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
657
+ hidden_states, ref_feature = attn(
658
+ hidden_states,
659
+ encoder_hidden_states=encoder_hidden_states,
660
+ cross_attention_kwargs=cross_attention_kwargs,
661
+ attention_mask=attention_mask,
662
+ encoder_attention_mask=encoder_attention_mask,
663
+ return_dict=False,
664
+ )
665
+
666
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
667
+ if i == len(blocks) - 1 and additional_residuals is not None:
668
+ hidden_states = hidden_states + additional_residuals
669
+
670
+ output_states = output_states + (hidden_states,)
671
+
672
+ if self.downsamplers is not None:
673
+ for downsampler in self.downsamplers:
674
+ hidden_states = downsampler(hidden_states, scale=lora_scale)
675
+
676
+ output_states = output_states + (hidden_states,)
677
+
678
+ return hidden_states, output_states
679
+
680
+
681
+ class DownBlock2D(nn.Module):
682
+ def __init__(
683
+ self,
684
+ in_channels: int,
685
+ out_channels: int,
686
+ temb_channels: int,
687
+ dropout: float = 0.0,
688
+ num_layers: int = 1,
689
+ resnet_eps: float = 1e-6,
690
+ resnet_time_scale_shift: str = "default",
691
+ resnet_act_fn: str = "swish",
692
+ resnet_groups: int = 32,
693
+ resnet_pre_norm: bool = True,
694
+ output_scale_factor: float = 1.0,
695
+ add_downsample: bool = True,
696
+ downsample_padding: int = 1,
697
+ ):
698
+ super().__init__()
699
+ resnets = []
700
+
701
+ for i in range(num_layers):
702
+ in_channels = in_channels if i == 0 else out_channels
703
+ resnets.append(
704
+ ResnetBlock2D(
705
+ in_channels=in_channels,
706
+ out_channels=out_channels,
707
+ temb_channels=temb_channels,
708
+ eps=resnet_eps,
709
+ groups=resnet_groups,
710
+ dropout=dropout,
711
+ time_embedding_norm=resnet_time_scale_shift,
712
+ non_linearity=resnet_act_fn,
713
+ output_scale_factor=output_scale_factor,
714
+ pre_norm=resnet_pre_norm,
715
+ )
716
+ )
717
+
718
+ self.resnets = nn.ModuleList(resnets)
719
+
720
+ if add_downsample:
721
+ self.downsamplers = nn.ModuleList(
722
+ [
723
+ Downsample2D(
724
+ out_channels,
725
+ use_conv=True,
726
+ out_channels=out_channels,
727
+ padding=downsample_padding,
728
+ name="op",
729
+ )
730
+ ]
731
+ )
732
+ else:
733
+ self.downsamplers = None
734
+
735
+ self.gradient_checkpointing = False
736
+
737
+ def forward(
738
+ self,
739
+ hidden_states: torch.FloatTensor,
740
+ temb: Optional[torch.FloatTensor] = None,
741
+ scale: float = 1.0,
742
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
743
+ output_states = ()
744
+
745
+ for resnet in self.resnets:
746
+ if self.training and self.gradient_checkpointing:
747
+
748
+ def create_custom_forward(module):
749
+ def custom_forward(*inputs):
750
+ return module(*inputs)
751
+
752
+ return custom_forward
753
+
754
+ if is_torch_version(">=", "1.11.0"):
755
+ hidden_states = torch.utils.checkpoint.checkpoint(
756
+ create_custom_forward(resnet),
757
+ hidden_states,
758
+ temb,
759
+ use_reentrant=False,
760
+ )
761
+ else:
762
+ hidden_states = torch.utils.checkpoint.checkpoint(
763
+ create_custom_forward(resnet), hidden_states, temb
764
+ )
765
+ else:
766
+ hidden_states = resnet(hidden_states, temb, scale=scale)
767
+
768
+ output_states = output_states + (hidden_states,)
769
+
770
+ if self.downsamplers is not None:
771
+ for downsampler in self.downsamplers:
772
+ hidden_states = downsampler(hidden_states, scale=scale)
773
+
774
+ output_states = output_states + (hidden_states,)
775
+
776
+ return hidden_states, output_states
777
+
778
+
779
+ class CrossAttnUpBlock2D(nn.Module):
780
+ def __init__(
781
+ self,
782
+ in_channels: int,
783
+ out_channels: int,
784
+ prev_output_channel: int,
785
+ temb_channels: int,
786
+ resolution_idx: Optional[int] = None,
787
+ dropout: float = 0.0,
788
+ num_layers: int = 1,
789
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
790
+ resnet_eps: float = 1e-6,
791
+ resnet_time_scale_shift: str = "default",
792
+ resnet_act_fn: str = "swish",
793
+ resnet_groups: int = 32,
794
+ resnet_pre_norm: bool = True,
795
+ num_attention_heads: int = 1,
796
+ cross_attention_dim: int = 1280,
797
+ output_scale_factor: float = 1.0,
798
+ add_upsample: bool = True,
799
+ dual_cross_attention: bool = False,
800
+ use_linear_projection: bool = False,
801
+ only_cross_attention: bool = False,
802
+ upcast_attention: bool = False,
803
+ attention_type: str = "default",
804
+ ):
805
+ super().__init__()
806
+ resnets = []
807
+ attentions = []
808
+
809
+ self.has_cross_attention = True
810
+ self.num_attention_heads = num_attention_heads
811
+
812
+ if isinstance(transformer_layers_per_block, int):
813
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
814
+
815
+ for i in range(num_layers):
816
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
817
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
818
+
819
+ resnets.append(
820
+ ResnetBlock2D(
821
+ in_channels=resnet_in_channels + res_skip_channels,
822
+ out_channels=out_channels,
823
+ temb_channels=temb_channels,
824
+ eps=resnet_eps,
825
+ groups=resnet_groups,
826
+ dropout=dropout,
827
+ time_embedding_norm=resnet_time_scale_shift,
828
+ non_linearity=resnet_act_fn,
829
+ output_scale_factor=output_scale_factor,
830
+ pre_norm=resnet_pre_norm,
831
+ )
832
+ )
833
+ if not dual_cross_attention:
834
+ attentions.append(
835
+ Transformer2DModel(
836
+ num_attention_heads,
837
+ out_channels // num_attention_heads,
838
+ in_channels=out_channels,
839
+ num_layers=transformer_layers_per_block[i],
840
+ cross_attention_dim=cross_attention_dim,
841
+ norm_num_groups=resnet_groups,
842
+ use_linear_projection=use_linear_projection,
843
+ only_cross_attention=only_cross_attention,
844
+ upcast_attention=upcast_attention,
845
+ attention_type=attention_type,
846
+ )
847
+ )
848
+ else:
849
+ attentions.append(
850
+ DualTransformer2DModel(
851
+ num_attention_heads,
852
+ out_channels // num_attention_heads,
853
+ in_channels=out_channels,
854
+ num_layers=1,
855
+ cross_attention_dim=cross_attention_dim,
856
+ norm_num_groups=resnet_groups,
857
+ )
858
+ )
859
+ self.attentions = nn.ModuleList(attentions)
860
+ self.resnets = nn.ModuleList(resnets)
861
+
862
+ if add_upsample:
863
+ self.upsamplers = nn.ModuleList(
864
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
865
+ )
866
+ else:
867
+ self.upsamplers = None
868
+
869
+ self.gradient_checkpointing = False
870
+ self.resolution_idx = resolution_idx
871
+
872
+ def forward(
873
+ self,
874
+ hidden_states: torch.FloatTensor,
875
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
876
+ temb: Optional[torch.FloatTensor] = None,
877
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
878
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
879
+ upsample_size: Optional[int] = None,
880
+ attention_mask: Optional[torch.FloatTensor] = None,
881
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
882
+ ) -> torch.FloatTensor:
883
+ lora_scale = (
884
+ cross_attention_kwargs.get("scale", 1.0)
885
+ if cross_attention_kwargs is not None
886
+ else 1.0
887
+ )
888
+ is_freeu_enabled = (
889
+ getattr(self, "s1", None)
890
+ and getattr(self, "s2", None)
891
+ and getattr(self, "b1", None)
892
+ and getattr(self, "b2", None)
893
+ )
894
+
895
+ for resnet, attn in zip(self.resnets, self.attentions):
896
+ # pop res hidden states
897
+ res_hidden_states = res_hidden_states_tuple[-1]
898
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
899
+
900
+ # FreeU: Only operate on the first two stages
901
+ if is_freeu_enabled:
902
+ hidden_states, res_hidden_states = apply_freeu(
903
+ self.resolution_idx,
904
+ hidden_states,
905
+ res_hidden_states,
906
+ s1=self.s1,
907
+ s2=self.s2,
908
+ b1=self.b1,
909
+ b2=self.b2,
910
+ )
911
+
912
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
913
+
914
+ if self.training and self.gradient_checkpointing:
915
+
916
+ def create_custom_forward(module, return_dict=None):
917
+ def custom_forward(*inputs):
918
+ if return_dict is not None:
919
+ return module(*inputs, return_dict=return_dict)
920
+ else:
921
+ return module(*inputs)
922
+
923
+ return custom_forward
924
+
925
+ ckpt_kwargs: Dict[str, Any] = (
926
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
927
+ )
928
+ hidden_states = torch.utils.checkpoint.checkpoint(
929
+ create_custom_forward(resnet),
930
+ hidden_states,
931
+ temb,
932
+ **ckpt_kwargs,
933
+ )
934
+ hidden_states, ref_feature = attn(
935
+ hidden_states,
936
+ encoder_hidden_states=encoder_hidden_states,
937
+ cross_attention_kwargs=cross_attention_kwargs,
938
+ attention_mask=attention_mask,
939
+ encoder_attention_mask=encoder_attention_mask,
940
+ return_dict=False,
941
+ )
942
+ else:
943
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
944
+ hidden_states, ref_feature = attn(
945
+ hidden_states,
946
+ encoder_hidden_states=encoder_hidden_states,
947
+ cross_attention_kwargs=cross_attention_kwargs,
948
+ attention_mask=attention_mask,
949
+ encoder_attention_mask=encoder_attention_mask,
950
+ return_dict=False,
951
+ )
952
+
953
+ if self.upsamplers is not None:
954
+ for upsampler in self.upsamplers:
955
+ hidden_states = upsampler(
956
+ hidden_states, upsample_size, scale=lora_scale
957
+ )
958
+
959
+ return hidden_states
960
+
961
+
962
+ class UpBlock2D(nn.Module):
963
+ def __init__(
964
+ self,
965
+ in_channels: int,
966
+ prev_output_channel: int,
967
+ out_channels: int,
968
+ temb_channels: int,
969
+ resolution_idx: Optional[int] = None,
970
+ dropout: float = 0.0,
971
+ num_layers: int = 1,
972
+ resnet_eps: float = 1e-6,
973
+ resnet_time_scale_shift: str = "default",
974
+ resnet_act_fn: str = "swish",
975
+ resnet_groups: int = 32,
976
+ resnet_pre_norm: bool = True,
977
+ output_scale_factor: float = 1.0,
978
+ add_upsample: bool = True,
979
+ ):
980
+ super().__init__()
981
+ resnets = []
982
+
983
+ for i in range(num_layers):
984
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
985
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
986
+
987
+ resnets.append(
988
+ ResnetBlock2D(
989
+ in_channels=resnet_in_channels + res_skip_channels,
990
+ out_channels=out_channels,
991
+ temb_channels=temb_channels,
992
+ eps=resnet_eps,
993
+ groups=resnet_groups,
994
+ dropout=dropout,
995
+ time_embedding_norm=resnet_time_scale_shift,
996
+ non_linearity=resnet_act_fn,
997
+ output_scale_factor=output_scale_factor,
998
+ pre_norm=resnet_pre_norm,
999
+ )
1000
+ )
1001
+
1002
+ self.resnets = nn.ModuleList(resnets)
1003
+
1004
+ if add_upsample:
1005
+ self.upsamplers = nn.ModuleList(
1006
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
1007
+ )
1008
+ else:
1009
+ self.upsamplers = None
1010
+
1011
+ self.gradient_checkpointing = False
1012
+ self.resolution_idx = resolution_idx
1013
+
1014
+ def forward(
1015
+ self,
1016
+ hidden_states: torch.FloatTensor,
1017
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
1018
+ temb: Optional[torch.FloatTensor] = None,
1019
+ upsample_size: Optional[int] = None,
1020
+ scale: float = 1.0,
1021
+ ) -> torch.FloatTensor:
1022
+ is_freeu_enabled = (
1023
+ getattr(self, "s1", None)
1024
+ and getattr(self, "s2", None)
1025
+ and getattr(self, "b1", None)
1026
+ and getattr(self, "b2", None)
1027
+ )
1028
+
1029
+ for resnet in self.resnets:
1030
+ # pop res hidden states
1031
+ res_hidden_states = res_hidden_states_tuple[-1]
1032
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1033
+
1034
+ # FreeU: Only operate on the first two stages
1035
+ if is_freeu_enabled:
1036
+ hidden_states, res_hidden_states = apply_freeu(
1037
+ self.resolution_idx,
1038
+ hidden_states,
1039
+ res_hidden_states,
1040
+ s1=self.s1,
1041
+ s2=self.s2,
1042
+ b1=self.b1,
1043
+ b2=self.b2,
1044
+ )
1045
+
1046
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1047
+
1048
+ if self.training and self.gradient_checkpointing:
1049
+
1050
+ def create_custom_forward(module):
1051
+ def custom_forward(*inputs):
1052
+ return module(*inputs)
1053
+
1054
+ return custom_forward
1055
+
1056
+ if is_torch_version(">=", "1.11.0"):
1057
+ hidden_states = torch.utils.checkpoint.checkpoint(
1058
+ create_custom_forward(resnet),
1059
+ hidden_states,
1060
+ temb,
1061
+ use_reentrant=False,
1062
+ )
1063
+ else:
1064
+ hidden_states = torch.utils.checkpoint.checkpoint(
1065
+ create_custom_forward(resnet), hidden_states, temb
1066
+ )
1067
+ else:
1068
+ hidden_states = resnet(hidden_states, temb, scale=scale)
1069
+
1070
+ if self.upsamplers is not None:
1071
+ for upsampler in self.upsamplers:
1072
+ hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
1073
+
1074
+ return hidden_states
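
The blocks above keep the stock diffusers block interfaces; the one behavioural change is that the cross-attention variants also receive a `ref_feature` back from each `Transformer2DModel` call. A minimal sketch of the plain `DownBlock2D` contract, assuming this repo is importable as `musepose` and that `torch`/`diffusers` are installed (shapes are illustrative only):

import torch
from musepose.models.unet_2d_blocks import DownBlock2D

# two resnets followed by a strided-conv downsampler (add_downsample defaults to True)
block = DownBlock2D(in_channels=320, out_channels=320, temb_channels=1280, num_layers=2)
x = torch.randn(1, 320, 64, 64)   # input feature map
temb = torch.randn(1, 1280)       # timestep embedding
hidden_states, skips = block(x, temb)
print(hidden_states.shape)        # torch.Size([1, 320, 32, 32]) after the downsampler
print(len(skips))                 # 3: one skip per resnet plus one after downsampling
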
models/unet_2d_condition.py ADDED
@@ -0,0 +1,1307 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.utils.checkpoint
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.loaders import UNet2DConditionLoadersMixin
10
+ from diffusers.models.activations import get_activation
11
+ from diffusers.models.attention_processor import (
12
+ ADDED_KV_ATTENTION_PROCESSORS,
13
+ CROSS_ATTENTION_PROCESSORS,
14
+ AttentionProcessor,
15
+ AttnAddedKVProcessor,
16
+ AttnProcessor,
17
+ )
18
+ from diffusers.models.embeddings import (
19
+ GaussianFourierProjection,
20
+ ImageHintTimeEmbedding,
21
+ ImageProjection,
22
+ ImageTimeEmbedding,
23
+ PositionNet,  # used below when attention_type is "gated"; available in the diffusers releases this file targets
+ TextImageProjection,
24
+ TextImageTimeEmbedding,
25
+ TextTimeEmbedding,
26
+ TimestepEmbedding,
27
+ Timesteps,
28
+ )
29
+ from diffusers.models.modeling_utils import ModelMixin
30
+ from diffusers.utils import (
31
+ USE_PEFT_BACKEND,
32
+ BaseOutput,
33
+ deprecate,
34
+ logging,
35
+ scale_lora_layers,
36
+ unscale_lora_layers,
37
+ )
38
+
39
+ from .unet_2d_blocks import (
40
+ UNetMidBlock2D,
41
+ UNetMidBlock2DCrossAttn,
42
+ get_down_block,
43
+ get_up_block,
44
+ )
45
+
46
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
47
+
48
+
49
+ @dataclass
50
+ class UNet2DConditionOutput(BaseOutput):
51
+ """
52
+ The output of [`UNet2DConditionModel`].
53
+
54
+ Args:
55
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
56
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
57
+ ref_features (`Tuple[torch.FloatTensor]`, *optional*):
+ Per-block hidden states collected during this reference pass, kept for reuse by the reference-attention mechanism.
+ """
58
+
59
+ sample: torch.FloatTensor = None
60
+ ref_features: Tuple[torch.FloatTensor] = None
61
+
62
+
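# A hedged usage sketch of the extended output (names are illustrative; `unet` is an instance
# of the class defined below):
# out = unet(latents, timestep, encoder_hidden_states)
# denoised, ref_feats = out.sample, out.ref_features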
63
+ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
64
+ r"""
65
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
66
+ shaped output.
67
+
68
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
69
+ for all models (such as downloading or saving).
70
+
71
+ Parameters:
72
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
73
+ Height and width of input/output sample.
74
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
75
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
76
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
77
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
78
+ Whether to flip the sin to cos in the time embedding.
79
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
80
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
81
+ The tuple of downsample blocks to use.
82
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
83
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
84
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
85
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
86
+ The tuple of upsample blocks to use.
87
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
88
+ Whether to include self-attention in the basic transformer blocks, see
89
+ [`~models.attention.BasicTransformerBlock`].
90
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
91
+ The tuple of output channels for each block.
92
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
93
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
94
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
95
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
96
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
97
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
98
+ If `None`, normalization and activation layers are skipped in post-processing.
99
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
100
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
101
+ The dimension of the cross attention features.
102
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
103
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
104
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
105
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
106
+ reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
107
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
108
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
109
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
110
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
111
+ encoder_hid_dim (`int`, *optional*, defaults to None):
112
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
113
+ dimension to `cross_attention_dim`.
114
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
115
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
116
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
117
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
118
+ num_attention_heads (`int`, *optional*):
119
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
120
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
121
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
122
+ class_embed_type (`str`, *optional*, defaults to `None`):
123
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
124
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
125
+ addition_embed_type (`str`, *optional*, defaults to `None`):
126
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
127
+ "text". "text" will use the `TextTimeEmbedding` layer.
128
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
129
+ Dimension for the timestep embeddings.
130
+ num_class_embeds (`int`, *optional*, defaults to `None`):
131
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
132
+ class conditioning with `class_embed_type` equal to `None`.
133
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
134
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
135
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
136
+ An optional override for the dimension of the projected time embedding.
137
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
138
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
139
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
140
+ timestep_post_act (`str`, *optional*, defaults to `None`):
141
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
142
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
143
+ The dimension of `cond_proj` layer in the timestep embedding.
144
+ conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
145
+ conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
146
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
147
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
148
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
149
+ embeddings with the class embeddings.
150
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
151
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
152
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
153
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
154
+ otherwise.
155
+ """
156
+
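# Loading sketch: since the class inherits ModelMixin/ConfigMixin it can be instantiated from a
# standard Stable Diffusion UNet folder; the checkpoint path below is a placeholder assumption.
# reference_unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")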
157
+ _supports_gradient_checkpointing = True
158
+
159
+ @register_to_config
160
+ def __init__(
161
+ self,
162
+ sample_size: Optional[int] = None,
163
+ in_channels: int = 4,
164
+ out_channels: int = 4,
165
+ center_input_sample: bool = False,
166
+ flip_sin_to_cos: bool = True,
167
+ freq_shift: int = 0,
168
+ down_block_types: Tuple[str] = (
169
+ "CrossAttnDownBlock2D",
170
+ "CrossAttnDownBlock2D",
171
+ "CrossAttnDownBlock2D",
172
+ "DownBlock2D",
173
+ ),
174
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
175
+ up_block_types: Tuple[str] = (
176
+ "UpBlock2D",
177
+ "CrossAttnUpBlock2D",
178
+ "CrossAttnUpBlock2D",
179
+ "CrossAttnUpBlock2D",
180
+ ),
181
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
182
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
183
+ layers_per_block: Union[int, Tuple[int]] = 2,
184
+ downsample_padding: int = 1,
185
+ mid_block_scale_factor: float = 1,
186
+ dropout: float = 0.0,
187
+ act_fn: str = "silu",
188
+ norm_num_groups: Optional[int] = 32,
189
+ norm_eps: float = 1e-5,
190
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
191
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
192
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
193
+ encoder_hid_dim: Optional[int] = None,
194
+ encoder_hid_dim_type: Optional[str] = None,
195
+ attention_head_dim: Union[int, Tuple[int]] = 8,
196
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
197
+ dual_cross_attention: bool = False,
198
+ use_linear_projection: bool = False,
199
+ class_embed_type: Optional[str] = None,
200
+ addition_embed_type: Optional[str] = None,
201
+ addition_time_embed_dim: Optional[int] = None,
202
+ num_class_embeds: Optional[int] = None,
203
+ upcast_attention: bool = False,
204
+ resnet_time_scale_shift: str = "default",
205
+ resnet_skip_time_act: bool = False,
206
+ resnet_out_scale_factor: int = 1.0,
207
+ time_embedding_type: str = "positional",
208
+ time_embedding_dim: Optional[int] = None,
209
+ time_embedding_act_fn: Optional[str] = None,
210
+ timestep_post_act: Optional[str] = None,
211
+ time_cond_proj_dim: Optional[int] = None,
212
+ conv_in_kernel: int = 3,
213
+ conv_out_kernel: int = 3,
214
+ projection_class_embeddings_input_dim: Optional[int] = None,
215
+ attention_type: str = "default",
216
+ class_embeddings_concat: bool = False,
217
+ mid_block_only_cross_attention: Optional[bool] = None,
218
+ cross_attention_norm: Optional[str] = None,
219
+ addition_embed_type_num_heads=64,
220
+ ):
221
+ super().__init__()
222
+
223
+ self.sample_size = sample_size
224
+
225
+ if num_attention_heads is not None:
226
+ raise ValueError(
227
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
228
+ )
229
+
230
+ # If `num_attention_heads` is not defined (which is the case for most models)
231
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
232
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
233
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
234
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
235
+ # which is why we correct for the naming here.
236
+ num_attention_heads = num_attention_heads or attention_head_dim
237
+
238
+ # Check inputs
239
+ if len(down_block_types) != len(up_block_types):
240
+ raise ValueError(
241
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
242
+ )
243
+
244
+ if len(block_out_channels) != len(down_block_types):
245
+ raise ValueError(
246
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
247
+ )
248
+
249
+ if not isinstance(only_cross_attention, bool) and len(
250
+ only_cross_attention
251
+ ) != len(down_block_types):
252
+ raise ValueError(
253
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
254
+ )
255
+
256
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
257
+ down_block_types
258
+ ):
259
+ raise ValueError(
260
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
261
+ )
262
+
263
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(
264
+ down_block_types
265
+ ):
266
+ raise ValueError(
267
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
268
+ )
269
+
270
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(
271
+ down_block_types
272
+ ):
273
+ raise ValueError(
274
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
275
+ )
276
+
277
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(
278
+ down_block_types
279
+ ):
280
+ raise ValueError(
281
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
282
+ )
283
+ if (
284
+ isinstance(transformer_layers_per_block, list)
285
+ and reverse_transformer_layers_per_block is None
286
+ ):
287
+ for layer_number_per_block in transformer_layers_per_block:
288
+ if isinstance(layer_number_per_block, list):
289
+ raise ValueError(
290
+ "Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet."
291
+ )
292
+
293
+ # input
294
+ conv_in_padding = (conv_in_kernel - 1) // 2
295
+ self.conv_in = nn.Conv2d(
296
+ in_channels,
297
+ block_out_channels[0],
298
+ kernel_size=conv_in_kernel,
299
+ padding=conv_in_padding,
300
+ )
301
+
302
+ # time
303
+ if time_embedding_type == "fourier":
304
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
305
+ if time_embed_dim % 2 != 0:
306
+ raise ValueError(
307
+ f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
308
+ )
309
+ self.time_proj = GaussianFourierProjection(
310
+ time_embed_dim // 2,
311
+ set_W_to_weight=False,
312
+ log=False,
313
+ flip_sin_to_cos=flip_sin_to_cos,
314
+ )
315
+ timestep_input_dim = time_embed_dim
316
+ elif time_embedding_type == "positional":
317
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
318
+
319
+ self.time_proj = Timesteps(
320
+ block_out_channels[0], flip_sin_to_cos, freq_shift
321
+ )
322
+ timestep_input_dim = block_out_channels[0]
323
+ else:
324
+ raise ValueError(
325
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
326
+ )
327
+
328
+ self.time_embedding = TimestepEmbedding(
329
+ timestep_input_dim,
330
+ time_embed_dim,
331
+ act_fn=act_fn,
332
+ post_act_fn=timestep_post_act,
333
+ cond_proj_dim=time_cond_proj_dim,
334
+ )
335
+
336
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
337
+ encoder_hid_dim_type = "text_proj"
338
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
339
+ logger.info(
340
+ "encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined."
341
+ )
342
+
343
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
344
+ raise ValueError(
345
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
346
+ )
347
+
348
+ if encoder_hid_dim_type == "text_proj":
349
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
350
+ elif encoder_hid_dim_type == "text_image_proj":
351
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
352
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
353
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1).
354
+ self.encoder_hid_proj = TextImageProjection(
355
+ text_embed_dim=encoder_hid_dim,
356
+ image_embed_dim=cross_attention_dim,
357
+ cross_attention_dim=cross_attention_dim,
358
+ )
359
+ elif encoder_hid_dim_type == "image_proj":
360
+ # Kandinsky 2.2
361
+ self.encoder_hid_proj = ImageProjection(
362
+ image_embed_dim=encoder_hid_dim,
363
+ cross_attention_dim=cross_attention_dim,
364
+ )
365
+ elif encoder_hid_dim_type is not None:
366
+ raise ValueError(
367
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
368
+ )
369
+ else:
370
+ self.encoder_hid_proj = None
371
+
372
+ # class embedding
373
+ if class_embed_type is None and num_class_embeds is not None:
374
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
375
+ elif class_embed_type == "timestep":
376
+ self.class_embedding = TimestepEmbedding(
377
+ timestep_input_dim, time_embed_dim, act_fn=act_fn
378
+ )
379
+ elif class_embed_type == "identity":
380
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
381
+ elif class_embed_type == "projection":
382
+ if projection_class_embeddings_input_dim is None:
383
+ raise ValueError(
384
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
385
+ )
386
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
387
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
388
+ # 2. it projects from an arbitrary input dimension.
389
+ #
390
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
391
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
392
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
393
+ self.class_embedding = TimestepEmbedding(
394
+ projection_class_embeddings_input_dim, time_embed_dim
395
+ )
396
+ elif class_embed_type == "simple_projection":
397
+ if projection_class_embeddings_input_dim is None:
398
+ raise ValueError(
399
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
400
+ )
401
+ self.class_embedding = nn.Linear(
402
+ projection_class_embeddings_input_dim, time_embed_dim
403
+ )
404
+ else:
405
+ self.class_embedding = None
406
+
407
+ if addition_embed_type == "text":
408
+ if encoder_hid_dim is not None:
409
+ text_time_embedding_from_dim = encoder_hid_dim
410
+ else:
411
+ text_time_embedding_from_dim = cross_attention_dim
412
+
413
+ self.add_embedding = TextTimeEmbedding(
414
+ text_time_embedding_from_dim,
415
+ time_embed_dim,
416
+ num_heads=addition_embed_type_num_heads,
417
+ )
418
+ elif addition_embed_type == "text_image":
419
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
420
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
421
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1).
422
+ self.add_embedding = TextImageTimeEmbedding(
423
+ text_embed_dim=cross_attention_dim,
424
+ image_embed_dim=cross_attention_dim,
425
+ time_embed_dim=time_embed_dim,
426
+ )
427
+ elif addition_embed_type == "text_time":
428
+ self.add_time_proj = Timesteps(
429
+ addition_time_embed_dim, flip_sin_to_cos, freq_shift
430
+ )
431
+ self.add_embedding = TimestepEmbedding(
432
+ projection_class_embeddings_input_dim, time_embed_dim
433
+ )
434
+ elif addition_embed_type == "image":
435
+ # Kandinsky 2.2
436
+ self.add_embedding = ImageTimeEmbedding(
437
+ image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
438
+ )
439
+ elif addition_embed_type == "image_hint":
440
+ # Kandinsky 2.2 ControlNet
441
+ self.add_embedding = ImageHintTimeEmbedding(
442
+ image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
443
+ )
444
+ elif addition_embed_type is not None:
445
+ raise ValueError(
446
+ f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'."
447
+ )
448
+
449
+ if time_embedding_act_fn is None:
450
+ self.time_embed_act = None
451
+ else:
452
+ self.time_embed_act = get_activation(time_embedding_act_fn)
453
+
454
+ self.down_blocks = nn.ModuleList([])
455
+ self.up_blocks = nn.ModuleList([])
456
+
457
+ if isinstance(only_cross_attention, bool):
458
+ if mid_block_only_cross_attention is None:
459
+ mid_block_only_cross_attention = only_cross_attention
460
+
461
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
462
+
463
+ if mid_block_only_cross_attention is None:
464
+ mid_block_only_cross_attention = False
465
+
466
+ if isinstance(num_attention_heads, int):
467
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
468
+
469
+ if isinstance(attention_head_dim, int):
470
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
471
+
472
+ if isinstance(cross_attention_dim, int):
473
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
474
+
475
+ if isinstance(layers_per_block, int):
476
+ layers_per_block = [layers_per_block] * len(down_block_types)
477
+
478
+ if isinstance(transformer_layers_per_block, int):
479
+ transformer_layers_per_block = [transformer_layers_per_block] * len(
480
+ down_block_types
481
+ )
482
+
483
+ if class_embeddings_concat:
484
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
485
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
486
+ # regular time embeddings
487
+ blocks_time_embed_dim = time_embed_dim * 2
488
+ else:
489
+ blocks_time_embed_dim = time_embed_dim
490
+
491
+ # down
492
+ output_channel = block_out_channels[0]
493
+ for i, down_block_type in enumerate(down_block_types):
494
+ input_channel = output_channel
495
+ output_channel = block_out_channels[i]
496
+ is_final_block = i == len(block_out_channels) - 1
497
+
498
+ down_block = get_down_block(
499
+ down_block_type,
500
+ num_layers=layers_per_block[i],
501
+ transformer_layers_per_block=transformer_layers_per_block[i],
502
+ in_channels=input_channel,
503
+ out_channels=output_channel,
504
+ temb_channels=blocks_time_embed_dim,
505
+ add_downsample=not is_final_block,
506
+ resnet_eps=norm_eps,
507
+ resnet_act_fn=act_fn,
508
+ resnet_groups=norm_num_groups,
509
+ cross_attention_dim=cross_attention_dim[i],
510
+ num_attention_heads=num_attention_heads[i],
511
+ downsample_padding=downsample_padding,
512
+ dual_cross_attention=dual_cross_attention,
513
+ use_linear_projection=use_linear_projection,
514
+ only_cross_attention=only_cross_attention[i],
515
+ upcast_attention=upcast_attention,
516
+ resnet_time_scale_shift=resnet_time_scale_shift,
517
+ attention_type=attention_type,
518
+ resnet_skip_time_act=resnet_skip_time_act,
519
+ resnet_out_scale_factor=resnet_out_scale_factor,
520
+ cross_attention_norm=cross_attention_norm,
521
+ attention_head_dim=attention_head_dim[i]
522
+ if attention_head_dim[i] is not None
523
+ else output_channel,
524
+ dropout=dropout,
525
+ )
526
+ self.down_blocks.append(down_block)
527
+
528
+ # mid
529
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
530
+ self.mid_block = UNetMidBlock2DCrossAttn(
531
+ transformer_layers_per_block=transformer_layers_per_block[-1],
532
+ in_channels=block_out_channels[-1],
533
+ temb_channels=blocks_time_embed_dim,
534
+ dropout=dropout,
535
+ resnet_eps=norm_eps,
536
+ resnet_act_fn=act_fn,
537
+ output_scale_factor=mid_block_scale_factor,
538
+ resnet_time_scale_shift=resnet_time_scale_shift,
539
+ cross_attention_dim=cross_attention_dim[-1],
540
+ num_attention_heads=num_attention_heads[-1],
541
+ resnet_groups=norm_num_groups,
542
+ dual_cross_attention=dual_cross_attention,
543
+ use_linear_projection=use_linear_projection,
544
+ upcast_attention=upcast_attention,
545
+ attention_type=attention_type,
546
+ )
547
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
548
+ raise NotImplementedError(f"Unsupported mid_block_type: {mid_block_type}")
549
+ elif mid_block_type == "UNetMidBlock2D":
550
+ self.mid_block = UNetMidBlock2D(
551
+ in_channels=block_out_channels[-1],
552
+ temb_channels=blocks_time_embed_dim,
553
+ dropout=dropout,
554
+ num_layers=0,
555
+ resnet_eps=norm_eps,
556
+ resnet_act_fn=act_fn,
557
+ output_scale_factor=mid_block_scale_factor,
558
+ resnet_groups=norm_num_groups,
559
+ resnet_time_scale_shift=resnet_time_scale_shift,
560
+ add_attention=False,
561
+ )
562
+ elif mid_block_type is None:
563
+ self.mid_block = None
564
+ else:
565
+ raise ValueError(f"unknown mid_block_type: {mid_block_type}")
566
+
567
+ # count how many layers upsample the images
568
+ self.num_upsamplers = 0
569
+
570
+ # up
571
+ reversed_block_out_channels = list(reversed(block_out_channels))
572
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
573
+ reversed_layers_per_block = list(reversed(layers_per_block))
574
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
575
+ reversed_transformer_layers_per_block = (
576
+ list(reversed(transformer_layers_per_block))
577
+ if reverse_transformer_layers_per_block is None
578
+ else reverse_transformer_layers_per_block
579
+ )
580
+ only_cross_attention = list(reversed(only_cross_attention))
581
+
582
+ output_channel = reversed_block_out_channels[0]
583
+ for i, up_block_type in enumerate(up_block_types):
584
+ is_final_block = i == len(block_out_channels) - 1
585
+
586
+ prev_output_channel = output_channel
587
+ output_channel = reversed_block_out_channels[i]
588
+ input_channel = reversed_block_out_channels[
589
+ min(i + 1, len(block_out_channels) - 1)
590
+ ]
591
+
592
+ # add upsample block for all BUT final layer
593
+ if not is_final_block:
594
+ add_upsample = True
595
+ self.num_upsamplers += 1
596
+ else:
597
+ add_upsample = False
598
+
599
+ up_block = get_up_block(
600
+ up_block_type,
601
+ num_layers=reversed_layers_per_block[i] + 1,
602
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
603
+ in_channels=input_channel,
604
+ out_channels=output_channel,
605
+ prev_output_channel=prev_output_channel,
606
+ temb_channels=blocks_time_embed_dim,
607
+ add_upsample=add_upsample,
608
+ resnet_eps=norm_eps,
609
+ resnet_act_fn=act_fn,
610
+ resolution_idx=i,
611
+ resnet_groups=norm_num_groups,
612
+ cross_attention_dim=reversed_cross_attention_dim[i],
613
+ num_attention_heads=reversed_num_attention_heads[i],
614
+ dual_cross_attention=dual_cross_attention,
615
+ use_linear_projection=use_linear_projection,
616
+ only_cross_attention=only_cross_attention[i],
617
+ upcast_attention=upcast_attention,
618
+ resnet_time_scale_shift=resnet_time_scale_shift,
619
+ attention_type=attention_type,
620
+ resnet_skip_time_act=resnet_skip_time_act,
621
+ resnet_out_scale_factor=resnet_out_scale_factor,
622
+ cross_attention_norm=cross_attention_norm,
623
+ attention_head_dim=attention_head_dim[i]
624
+ if attention_head_dim[i] is not None
625
+ else output_channel,
626
+ dropout=dropout,
627
+ )
628
+ self.up_blocks.append(up_block)
629
+ prev_output_channel = output_channel
630
+
631
+ # out
632
+ if norm_num_groups is not None:
633
+ self.conv_norm_out = nn.GroupNorm(
634
+ num_channels=block_out_channels[0],
635
+ num_groups=norm_num_groups,
636
+ eps=norm_eps,
637
+ )
638
+
639
+ self.conv_act = get_activation(act_fn)
640
+
641
+ else:
642
+ self.conv_norm_out = None
643
+ self.conv_act = None
644
+ self.conv_norm_out = None
645
+
646
+ conv_out_padding = (conv_out_kernel - 1) // 2
647
+ # self.conv_out = nn.Conv2d(
648
+ # block_out_channels[0],
649
+ # out_channels,
650
+ # kernel_size=conv_out_kernel,
651
+ # padding=conv_out_padding,
652
+ # )
653
+
654
+ if attention_type in ["gated", "gated-text-image"]:
655
+ positive_len = 768
656
+ if isinstance(cross_attention_dim, int):
657
+ positive_len = cross_attention_dim
658
+ elif isinstance(cross_attention_dim, tuple) or isinstance(
659
+ cross_attention_dim, list
660
+ ):
661
+ positive_len = cross_attention_dim[0]
662
+
663
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
664
+ self.position_net = PositionNet(
665
+ positive_len=positive_len,
666
+ out_dim=cross_attention_dim,
667
+ feature_type=feature_type,
668
+ )
669
+
670
+ @property
671
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
672
+ r"""
673
+ Returns:
674
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
675
+ indexed by its weight name.
676
+ """
677
+ # set recursively
678
+ processors = {}
679
+
680
+ def fn_recursive_add_processors(
681
+ name: str,
682
+ module: torch.nn.Module,
683
+ processors: Dict[str, AttentionProcessor],
684
+ ):
685
+ if hasattr(module, "get_processor"):
686
+ processors[f"{name}.processor"] = module.get_processor(
687
+ return_deprecated_lora=True
688
+ )
689
+
690
+ for sub_name, child in module.named_children():
691
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
692
+
693
+ return processors
694
+
695
+ for name, module in self.named_children():
696
+ fn_recursive_add_processors(name, module, processors)
697
+
698
+ return processors
699
+
700
+ def set_attn_processor(
701
+ self,
702
+ processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]],
703
+ _remove_lora=False,
704
+ ):
705
+ r"""
706
+ Sets the attention processor to use to compute attention.
707
+
708
+ Parameters:
709
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
710
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
711
+ for **all** `Attention` layers.
712
+
713
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
714
+ processor. This is strongly recommended when setting trainable attention processors.
715
+
716
+ """
717
+ count = len(self.attn_processors.keys())
718
+
719
+ if isinstance(processor, dict) and len(processor) != count:
720
+ raise ValueError(
721
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
722
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
723
+ )
724
+
725
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
726
+ if hasattr(module, "set_processor"):
727
+ if not isinstance(processor, dict):
728
+ module.set_processor(processor, _remove_lora=_remove_lora)
729
+ else:
730
+ module.set_processor(
731
+ processor.pop(f"{name}.processor"), _remove_lora=_remove_lora
732
+ )
733
+
734
+ for sub_name, child in module.named_children():
735
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
736
+
737
+ for name, module in self.named_children():
738
+ fn_recursive_attn_processor(name, module, processor)
739
+
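# Sketch of swapping processors (assumes a diffusers release that ships AttnProcessor2_0;
# `unet` is an instance of this class):
# from diffusers.models.attention_processor import AttnProcessor2_0
# unet.set_attn_processor(AttnProcessor2_0())  # one processor shared by every attention layer
# unet.set_attn_processor({name: AttnProcessor2_0() for name in unet.attn_processors})  # per-layer dict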
740
+ def set_default_attn_processor(self):
741
+ """
742
+ Disables custom attention processors and sets the default attention implementation.
743
+ """
744
+ if all(
745
+ proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
746
+ for proc in self.attn_processors.values()
747
+ ):
748
+ processor = AttnAddedKVProcessor()
749
+ elif all(
750
+ proc.__class__ in CROSS_ATTENTION_PROCESSORS
751
+ for proc in self.attn_processors.values()
752
+ ):
753
+ processor = AttnProcessor()
754
+ else:
755
+ raise ValueError(
756
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
757
+ )
758
+
759
+ self.set_attn_processor(processor, _remove_lora=True)
760
+
761
+ def set_attention_slice(self, slice_size):
762
+ r"""
763
+ Enable sliced attention computation.
764
+
765
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
766
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
767
+
768
+ Args:
769
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
770
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
771
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
772
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
773
+ must be a multiple of `slice_size`.
774
+ """
775
+ sliceable_head_dims = []
776
+
777
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
778
+ if hasattr(module, "set_attention_slice"):
779
+ sliceable_head_dims.append(module.sliceable_head_dim)
780
+
781
+ for child in module.children():
782
+ fn_recursive_retrieve_sliceable_dims(child)
783
+
784
+ # retrieve number of attention layers
785
+ for module in self.children():
786
+ fn_recursive_retrieve_sliceable_dims(module)
787
+
788
+ num_sliceable_layers = len(sliceable_head_dims)
789
+
790
+ if slice_size == "auto":
791
+ # half the attention head size is usually a good trade-off between
792
+ # speed and memory
793
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
794
+ elif slice_size == "max":
795
+ # make smallest slice possible
796
+ slice_size = num_sliceable_layers * [1]
797
+
798
+ slice_size = (
799
+ num_sliceable_layers * [slice_size]
800
+ if not isinstance(slice_size, list)
801
+ else slice_size
802
+ )
803
+
804
+ if len(slice_size) != len(sliceable_head_dims):
805
+ raise ValueError(
806
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
807
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
808
+ )
809
+
810
+ for i in range(len(slice_size)):
811
+ size = slice_size[i]
812
+ dim = sliceable_head_dims[i]
813
+ if size is not None and size > dim:
814
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
815
+
816
+ # Recursively walk through all the children.
817
+ # Any children which exposes the set_attention_slice method
818
+ # gets the message
819
+ def fn_recursive_set_attention_slice(
820
+ module: torch.nn.Module, slice_size: List[int]
821
+ ):
822
+ if hasattr(module, "set_attention_slice"):
823
+ module.set_attention_slice(slice_size.pop())
824
+
825
+ for child in module.children():
826
+ fn_recursive_set_attention_slice(child, slice_size)
827
+
828
+ reversed_slice_size = list(reversed(slice_size))
829
+ for module in self.children():
830
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
831
+
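# Illustrative calls: halve the per-head workload, or run one slice at a time for minimum memory.
# unet.set_attention_slice("auto")
# unet.set_attention_slice("max")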
832
+ def _set_gradient_checkpointing(self, module, value=False):
833
+ if hasattr(module, "gradient_checkpointing"):
834
+ module.gradient_checkpointing = value
835
+
836
+ def enable_freeu(self, s1, s2, b1, b2):
837
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
838
+
839
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
840
+
841
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
842
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
843
+
844
+ Args:
845
+ s1 (`float`):
846
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
847
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
848
+ s2 (`float`):
849
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
850
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
851
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
852
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
853
+ """
854
+ for i, upsample_block in enumerate(self.up_blocks):
855
+ setattr(upsample_block, "s1", s1)
856
+ setattr(upsample_block, "s2", s2)
857
+ setattr(upsample_block, "b1", b1)
858
+ setattr(upsample_block, "b2", b2)
859
+
860
+ def disable_freeu(self):
861
+ """Disables the FreeU mechanism."""
862
+ freeu_keys = {"s1", "s2", "b1", "b2"}
863
+ for i, upsample_block in enumerate(self.up_blocks):
864
+ for k in freeu_keys:
865
+ if (
866
+ hasattr(upsample_block, k)
867
+ or getattr(upsample_block, k, None) is not None
868
+ ):
869
+ setattr(upsample_block, k, None)
870
+
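# Illustrative call with the scaling factors the FreeU authors report for SD 1.5 (treat the
# exact numbers as an assumption to tune per pipeline):
# unet.enable_freeu(s1=0.9, s2=0.2, b1=1.5, b2=1.6)
# unet.disable_freeu()  # back to the vanilla skip/backbone balance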
871
+ def forward(
872
+ self,
873
+ sample: torch.FloatTensor,
874
+ timestep: Union[torch.Tensor, float, int],
875
+ encoder_hidden_states: torch.Tensor,
876
+ class_labels: Optional[torch.Tensor] = None,
877
+ timestep_cond: Optional[torch.Tensor] = None,
878
+ attention_mask: Optional[torch.Tensor] = None,
879
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
880
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
881
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
882
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
883
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
884
+ encoder_attention_mask: Optional[torch.Tensor] = None,
885
+ return_dict: bool = True,
886
+ ) -> Union[UNet2DConditionOutput, Tuple]:
887
+ r"""
888
+ The [`UNet2DConditionModel`] forward method.
889
+
890
+ Args:
891
+ sample (`torch.FloatTensor`):
892
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
893
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
894
+ encoder_hidden_states (`torch.FloatTensor`):
895
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
896
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
897
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
898
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
899
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
900
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
901
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
902
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
903
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
904
+ negative values to the attention scores corresponding to "discard" tokens.
905
+ cross_attention_kwargs (`dict`, *optional*):
906
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
907
+ `self.processor` in
908
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
909
+ added_cond_kwargs: (`dict`, *optional*):
910
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
911
+ are passed along to the UNet blocks.
912
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
913
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
914
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
915
+ A tensor that if specified is added to the residual of the middle unet block.
916
+ encoder_attention_mask (`torch.Tensor`):
917
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
918
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
919
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
920
+ return_dict (`bool`, *optional*, defaults to `True`):
921
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
922
+ tuple.
923
+ cross_attention_kwargs (`dict`, *optional*):
924
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
925
+ added_cond_kwargs: (`dict`, *optional*):
926
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
927
+ are passed along to the UNet blocks.
928
+ down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
929
+ additional residuals to be added to UNet long skip connections from down blocks to up blocks for
930
+ example from ControlNet side model(s)
931
+ mid_block_additional_residual (`torch.Tensor`, *optional*):
932
+ additional residual to be added to UNet mid block output, for example from ControlNet side model
933
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
934
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
935
+
936
+ Returns:
937
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
938
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
939
+ a `tuple` is returned where the first element is the sample tensor.
940
+ """
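# Call sketch (shapes follow the SD 1.5 latent convention and are an assumption):
# latents: (B, 4, H // 8, W // 8), encoder_hidden_states: (B, 77, 768)
# out = unet(latents, t, encoder_hidden_states)
# out.sample keeps the latent shape; out.ref_features, if populated, holds the reference-branch activations.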
941
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
942
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
943
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
944
+ # on the fly if necessary.
945
+ default_overall_up_factor = 2**self.num_upsamplers
946
+
947
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
948
+ forward_upsample_size = False
949
+ upsample_size = None
950
+
951
+ for dim in sample.shape[-2:]:
952
+ if dim % default_overall_up_factor != 0:
953
+ # Forward upsample size to force interpolation output size.
954
+ forward_upsample_size = True
955
+ break
956
+
957
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
958
+ # expects mask of shape:
959
+ # [batch, key_tokens]
960
+ # adds singleton query_tokens dimension:
961
+ # [batch, 1, key_tokens]
962
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
963
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
964
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
965
+ if attention_mask is not None:
966
+ # assume that mask is expressed as:
967
+ # (1 = keep, 0 = discard)
968
+ # convert mask into a bias that can be added to attention scores:
969
+ # (keep = +0, discard = -10000.0)
970
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
971
+ attention_mask = attention_mask.unsqueeze(1)
972
+
973
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
974
+ if encoder_attention_mask is not None:
975
+ encoder_attention_mask = (
976
+ 1 - encoder_attention_mask.to(sample.dtype)
977
+ ) * -10000.0
978
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
979
+
980
+ # 0. center input if necessary
981
+ if self.config.center_input_sample:
982
+ sample = 2 * sample - 1.0
983
+
984
+ # 1. time
985
+ timesteps = timestep
986
+ if not torch.is_tensor(timesteps):
987
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
988
+ # This would be a good case for the `match` statement (Python 3.10+)
989
+ is_mps = sample.device.type == "mps"
990
+ if isinstance(timestep, float):
991
+ dtype = torch.float32 if is_mps else torch.float64
992
+ else:
993
+ dtype = torch.int32 if is_mps else torch.int64
994
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
995
+ elif len(timesteps.shape) == 0:
996
+ timesteps = timesteps[None].to(sample.device)
997
+
998
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
999
+ timesteps = timesteps.expand(sample.shape[0])
1000
+
1001
+ t_emb = self.time_proj(timesteps)
1002
+
1003
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1004
+ # but time_embedding might actually be running in fp16. so we need to cast here.
1005
+ # there might be better ways to encapsulate this.
1006
+ t_emb = t_emb.to(dtype=sample.dtype)
1007
+
1008
+ emb = self.time_embedding(t_emb, timestep_cond)
1009
+ aug_emb = None
1010
+
1011
+ if self.class_embedding is not None:
1012
+ if class_labels is None:
1013
+ raise ValueError(
1014
+ "class_labels should be provided when num_class_embeds > 0"
1015
+ )
1016
+
1017
+ if self.config.class_embed_type == "timestep":
1018
+ class_labels = self.time_proj(class_labels)
1019
+
1020
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1021
+ # there might be better ways to encapsulate this.
1022
+ class_labels = class_labels.to(dtype=sample.dtype)
1023
+
1024
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
1025
+
1026
+ if self.config.class_embeddings_concat:
1027
+ emb = torch.cat([emb, class_emb], dim=-1)
1028
+ else:
1029
+ emb = emb + class_emb
1030
+
1031
+ if self.config.addition_embed_type == "text":
1032
+ aug_emb = self.add_embedding(encoder_hidden_states)
1033
+ elif self.config.addition_embed_type == "text_image":
1034
+ # Kandinsky 2.1 - style
1035
+ if "image_embeds" not in added_cond_kwargs:
1036
+ raise ValueError(
1037
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1038
+ )
1039
+
1040
+ image_embs = added_cond_kwargs.get("image_embeds")
1041
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
1042
+ aug_emb = self.add_embedding(text_embs, image_embs)
1043
+ elif self.config.addition_embed_type == "text_time":
1044
+ # SDXL - style
1045
+ if "text_embeds" not in added_cond_kwargs:
1046
+ raise ValueError(
1047
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
1048
+ )
1049
+ text_embeds = added_cond_kwargs.get("text_embeds")
1050
+ if "time_ids" not in added_cond_kwargs:
1051
+ raise ValueError(
1052
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
1053
+ )
1054
+ time_ids = added_cond_kwargs.get("time_ids")
1055
+ time_embeds = self.add_time_proj(time_ids.flatten())
1056
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
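+ # each id in time_ids is embedded separately, then regrouped to (batch, num_ids * addition_time_embed_dim);
+ # for SDXL-style conditioning num_ids is typically 6 (original size, crop coords, target size)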
1057
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
1058
+ add_embeds = add_embeds.to(emb.dtype)
1059
+ aug_emb = self.add_embedding(add_embeds)
1060
+ elif self.config.addition_embed_type == "image":
1061
+ # Kandinsky 2.2 - style
1062
+ if "image_embeds" not in added_cond_kwargs:
1063
+ raise ValueError(
1064
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1065
+ )
1066
+ image_embs = added_cond_kwargs.get("image_embeds")
1067
+ aug_emb = self.add_embedding(image_embs)
1068
+ elif self.config.addition_embed_type == "image_hint":
1069
+ # Kandinsky 2.2 - style
1070
+ if (
1071
+ "image_embeds" not in added_cond_kwargs
1072
+ or "hint" not in added_cond_kwargs
1073
+ ):
1074
+ raise ValueError(
1075
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1076
+ )
1077
+ image_embs = added_cond_kwargs.get("image_embeds")
1078
+ hint = added_cond_kwargs.get("hint")
1079
+ aug_emb, hint = self.add_embedding(image_embs, hint)
1080
+ sample = torch.cat([sample, hint], dim=1)
1081
+
1082
+ emb = emb + aug_emb if aug_emb is not None else emb
1083
+
1084
+ if self.time_embed_act is not None:
1085
+ emb = self.time_embed_act(emb)
1086
+
1087
+ if (
1088
+ self.encoder_hid_proj is not None
1089
+ and self.config.encoder_hid_dim_type == "text_proj"
1090
+ ):
1091
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1092
+ elif (
1093
+ self.encoder_hid_proj is not None
1094
+ and self.config.encoder_hid_dim_type == "text_image_proj"
1095
+ ):
1096
+ # Kandinsky 2.1 - style
1097
+ if "image_embeds" not in added_cond_kwargs:
1098
+ raise ValueError(
1099
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1100
+ )
1101
+
1102
+ image_embeds = added_cond_kwargs.get("image_embeds")
1103
+ encoder_hidden_states = self.encoder_hid_proj(
1104
+ encoder_hidden_states, image_embeds
1105
+ )
1106
+ elif (
1107
+ self.encoder_hid_proj is not None
1108
+ and self.config.encoder_hid_dim_type == "image_proj"
1109
+ ):
1110
+ # Kandinsky 2.2 - style
1111
+ if "image_embeds" not in added_cond_kwargs:
1112
+ raise ValueError(
1113
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1114
+ )
1115
+ image_embeds = added_cond_kwargs.get("image_embeds")
1116
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1117
+ elif (
1118
+ self.encoder_hid_proj is not None
1119
+ and self.config.encoder_hid_dim_type == "ip_image_proj"
1120
+ ):
1121
+ if "image_embeds" not in added_cond_kwargs:
1122
+ raise ValueError(
1123
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1124
+ )
1125
+ image_embeds = added_cond_kwargs.get("image_embeds")
1126
+ image_embeds = self.encoder_hid_proj(image_embeds).to(
1127
+ encoder_hidden_states.dtype
1128
+ )
1129
+ encoder_hidden_states = torch.cat(
1130
+ [encoder_hidden_states, image_embeds], dim=1
1131
+ )
1132
+
1133
+ # 2. pre-process
1134
+ sample = self.conv_in(sample)
1135
+
1136
+ # 2.5 GLIGEN position net
1137
+ if (
1138
+ cross_attention_kwargs is not None
1139
+ and cross_attention_kwargs.get("gligen", None) is not None
1140
+ ):
1141
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1142
+ gligen_args = cross_attention_kwargs.pop("gligen")
1143
+ cross_attention_kwargs["gligen"] = {
1144
+ "objs": self.position_net(**gligen_args)
1145
+ }
1146
+
1147
+ # 3. down
1148
+ lora_scale = (
1149
+ cross_attention_kwargs.get("scale", 1.0)
1150
+ if cross_attention_kwargs is not None
1151
+ else 1.0
1152
+ )
1153
+ if USE_PEFT_BACKEND:
1154
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1155
+ scale_lora_layers(self, lora_scale)
1156
+
1157
+ is_controlnet = (
1158
+ mid_block_additional_residual is not None
1159
+ and down_block_additional_residuals is not None
1160
+ )
1161
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1162
+ is_adapter = down_intrablock_additional_residuals is not None
1163
+ # maintain backward compatibility for legacy usage, where
1164
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1165
+ # but can only use one or the other
1166
+ if (
1167
+ not is_adapter
1168
+ and mid_block_additional_residual is None
1169
+ and down_block_additional_residuals is not None
1170
+ ):
1171
+ deprecate(
1172
+ "T2I should not use down_block_additional_residuals",
1173
+ "1.3.0",
1174
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1175
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1176
+ for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
1177
+ standard_warn=False,
1178
+ )
1179
+ down_intrablock_additional_residuals = down_block_additional_residuals
1180
+ is_adapter = True
1181
+
1182
+ down_block_res_samples = (sample,)
1183
+ tot_reference_features = ()
1184
+ for downsample_block in self.down_blocks:
1185
+ if (
1186
+ hasattr(downsample_block, "has_cross_attention")
1187
+ and downsample_block.has_cross_attention
1188
+ ):
1189
+ # For t2i-adapter CrossAttnDownBlock2D
1190
+ additional_residuals = {}
1191
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1192
+ additional_residuals[
1193
+ "additional_residuals"
1194
+ ] = down_intrablock_additional_residuals.pop(0)
1195
+
1196
+ sample, res_samples = downsample_block(
1197
+ hidden_states=sample,
1198
+ temb=emb,
1199
+ encoder_hidden_states=encoder_hidden_states,
1200
+ attention_mask=attention_mask,
1201
+ cross_attention_kwargs=cross_attention_kwargs,
1202
+ encoder_attention_mask=encoder_attention_mask,
1203
+ **additional_residuals,
1204
+ )
1205
+ else:
1206
+ sample, res_samples = downsample_block(
1207
+ hidden_states=sample, temb=emb, scale=lora_scale
1208
+ )
1209
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1210
+ sample += down_intrablock_additional_residuals.pop(0)
1211
+
1212
+ down_block_res_samples += res_samples
1213
+
1214
+ if is_controlnet:
1215
+ new_down_block_res_samples = ()
1216
+
1217
+ for down_block_res_sample, down_block_additional_residual in zip(
1218
+ down_block_res_samples, down_block_additional_residuals
1219
+ ):
1220
+ down_block_res_sample = (
1221
+ down_block_res_sample + down_block_additional_residual
1222
+ )
1223
+ new_down_block_res_samples = new_down_block_res_samples + (
1224
+ down_block_res_sample,
1225
+ )
1226
+
1227
+ down_block_res_samples = new_down_block_res_samples
1228
+
1229
+ # 4. mid
1230
+ if self.mid_block is not None:
1231
+ if (
1232
+ hasattr(self.mid_block, "has_cross_attention")
1233
+ and self.mid_block.has_cross_attention
1234
+ ):
1235
+ sample = self.mid_block(
1236
+ sample,
1237
+ emb,
1238
+ encoder_hidden_states=encoder_hidden_states,
1239
+ attention_mask=attention_mask,
1240
+ cross_attention_kwargs=cross_attention_kwargs,
1241
+ encoder_attention_mask=encoder_attention_mask,
1242
+ )
1243
+ else:
1244
+ sample = self.mid_block(sample, emb)
1245
+
1246
+ # To support T2I-Adapter-XL
1247
+ if (
1248
+ is_adapter
1249
+ and len(down_intrablock_additional_residuals) > 0
1250
+ and sample.shape == down_intrablock_additional_residuals[0].shape
1251
+ ):
1252
+ sample += down_intrablock_additional_residuals.pop(0)
1253
+
1254
+ if is_controlnet:
1255
+ sample = sample + mid_block_additional_residual
1256
+
1257
+ # 5. up
1258
+ for i, upsample_block in enumerate(self.up_blocks):
1259
+ is_final_block = i == len(self.up_blocks) - 1
1260
+
1261
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1262
+ down_block_res_samples = down_block_res_samples[
1263
+ : -len(upsample_block.resnets)
1264
+ ]
1265
+
1266
+ # if we have not reached the final block and need to forward the
1267
+ # upsample size, we do it here
1268
+ if not is_final_block and forward_upsample_size:
1269
+ upsample_size = down_block_res_samples[-1].shape[2:]
1270
+
1271
+ if (
1272
+ hasattr(upsample_block, "has_cross_attention")
1273
+ and upsample_block.has_cross_attention
1274
+ ):
1275
+ sample = upsample_block(
1276
+ hidden_states=sample,
1277
+ temb=emb,
1278
+ res_hidden_states_tuple=res_samples,
1279
+ encoder_hidden_states=encoder_hidden_states,
1280
+ cross_attention_kwargs=cross_attention_kwargs,
1281
+ upsample_size=upsample_size,
1282
+ attention_mask=attention_mask,
1283
+ encoder_attention_mask=encoder_attention_mask,
1284
+ )
1285
+ else:
1286
+ sample = upsample_block(
1287
+ hidden_states=sample,
1288
+ temb=emb,
1289
+ res_hidden_states_tuple=res_samples,
1290
+ upsample_size=upsample_size,
1291
+ scale=lora_scale,
1292
+ )
1293
+
1294
+ # 6. post-process
1295
+ # if self.conv_norm_out:
1296
+ # sample = self.conv_norm_out(sample)
1297
+ # sample = self.conv_act(sample)
1298
+ # sample = self.conv_out(sample)
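+ # NOTE: unlike the stock diffusers UNet2DConditionModel, the final norm/act/conv post-processing above is
+ # deliberately left commented out in this adaptation, so the un-projected feature map is returned as-is below.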
1299
+
1300
+ if USE_PEFT_BACKEND:
1301
+ # remove `lora_scale` from each PEFT layer
1302
+ unscale_lora_layers(self, lora_scale)
1303
+
1304
+ if not return_dict:
1305
+ return (sample,)
1306
+
1307
+ return UNet2DConditionOutput(sample=sample)
models/unet_3d.py ADDED
@@ -0,0 +1,675 @@
1
+ # Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet_blocks.py
2
+
3
+ from collections import OrderedDict
4
+ from dataclasses import dataclass
5
+ from os import PathLike
6
+ from pathlib import Path
7
+ from typing import Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
13
+ from diffusers.models.attention_processor import AttentionProcessor
14
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
15
+ from diffusers.models.modeling_utils import ModelMixin
16
+ from diffusers.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging
17
+ from safetensors.torch import load_file
18
+
19
+ from .resnet import InflatedConv3d, InflatedGroupNorm
20
+ from .unet_3d_blocks import UNetMidBlock3DCrossAttn, get_down_block, get_up_block
21
+
22
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
23
+
24
+
25
+ @dataclass
26
+ class UNet3DConditionOutput(BaseOutput):
27
+ sample: torch.FloatTensor
28
+
29
+
30
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
31
+ _supports_gradient_checkpointing = True
32
+
33
+ @register_to_config
34
+ def __init__(
35
+ self,
36
+ sample_size: Optional[int] = None,
37
+ in_channels: int = 4,
38
+ out_channels: int = 4,
39
+ center_input_sample: bool = False,
40
+ flip_sin_to_cos: bool = True,
41
+ freq_shift: int = 0,
42
+ down_block_types: Tuple[str] = (
43
+ "CrossAttnDownBlock3D",
44
+ "CrossAttnDownBlock3D",
45
+ "CrossAttnDownBlock3D",
46
+ "DownBlock3D",
47
+ ),
48
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
49
+ up_block_types: Tuple[str] = (
50
+ "UpBlock3D",
51
+ "CrossAttnUpBlock3D",
52
+ "CrossAttnUpBlock3D",
53
+ "CrossAttnUpBlock3D",
54
+ ),
55
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
56
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
57
+ layers_per_block: int = 2,
58
+ downsample_padding: int = 1,
59
+ mid_block_scale_factor: float = 1,
60
+ act_fn: str = "silu",
61
+ norm_num_groups: int = 32,
62
+ norm_eps: float = 1e-5,
63
+ cross_attention_dim: int = 1280,
64
+ attention_head_dim: Union[int, Tuple[int]] = 8,
65
+ dual_cross_attention: bool = False,
66
+ use_linear_projection: bool = False,
67
+ class_embed_type: Optional[str] = None,
68
+ num_class_embeds: Optional[int] = None,
69
+ upcast_attention: bool = False,
70
+ resnet_time_scale_shift: str = "default",
71
+ use_inflated_groupnorm=False,
72
+ # Additional
73
+ use_motion_module=False,
74
+ motion_module_resolutions=(1, 2, 4, 8),
75
+ motion_module_mid_block=False,
76
+ motion_module_decoder_only=False,
77
+ motion_module_type=None,
78
+ motion_module_kwargs={},
79
+ unet_use_cross_frame_attention=None,
80
+ unet_use_temporal_attention=None,
81
+ ):
82
+ super().__init__()
83
+
84
+ self.sample_size = sample_size
85
+ time_embed_dim = block_out_channels[0] * 4
86
+
87
+ # input
88
+ self.conv_in = InflatedConv3d(
89
+ in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)
90
+ )
91
+
92
+ # time
93
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
94
+ timestep_input_dim = block_out_channels[0]
95
+
96
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
97
+
98
+ # class embedding
99
+ if class_embed_type is None and num_class_embeds is not None:
100
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
101
+ elif class_embed_type == "timestep":
102
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
103
+ elif class_embed_type == "identity":
104
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
105
+ else:
106
+ self.class_embedding = None
107
+
108
+ self.down_blocks = nn.ModuleList([])
109
+ self.mid_block = None
110
+ self.up_blocks = nn.ModuleList([])
111
+
112
+ if isinstance(only_cross_attention, bool):
113
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
114
+
115
+ if isinstance(attention_head_dim, int):
116
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
117
+
118
+ # down
119
+ output_channel = block_out_channels[0]
120
+ for i, down_block_type in enumerate(down_block_types):
121
+ res = 2**i
122
+ input_channel = output_channel
123
+ output_channel = block_out_channels[i]
124
+ is_final_block = i == len(block_out_channels) - 1
125
+
126
+ down_block = get_down_block(
127
+ down_block_type,
128
+ num_layers=layers_per_block,
129
+ in_channels=input_channel,
130
+ out_channels=output_channel,
131
+ temb_channels=time_embed_dim,
132
+ add_downsample=not is_final_block,
133
+ resnet_eps=norm_eps,
134
+ resnet_act_fn=act_fn,
135
+ resnet_groups=norm_num_groups,
136
+ cross_attention_dim=cross_attention_dim,
137
+ attn_num_head_channels=attention_head_dim[i],
138
+ downsample_padding=downsample_padding,
139
+ dual_cross_attention=dual_cross_attention,
140
+ use_linear_projection=use_linear_projection,
141
+ only_cross_attention=only_cross_attention[i],
142
+ upcast_attention=upcast_attention,
143
+ resnet_time_scale_shift=resnet_time_scale_shift,
144
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
145
+ unet_use_temporal_attention=unet_use_temporal_attention,
146
+ use_inflated_groupnorm=use_inflated_groupnorm,
147
+ use_motion_module=use_motion_module
148
+ and (res in motion_module_resolutions)
149
+ and (not motion_module_decoder_only),
150
+ motion_module_type=motion_module_type,
151
+ motion_module_kwargs=motion_module_kwargs,
152
+ )
153
+ self.down_blocks.append(down_block)
154
+
155
+ # mid
156
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
157
+ self.mid_block = UNetMidBlock3DCrossAttn(
158
+ in_channels=block_out_channels[-1],
159
+ temb_channels=time_embed_dim,
160
+ resnet_eps=norm_eps,
161
+ resnet_act_fn=act_fn,
162
+ output_scale_factor=mid_block_scale_factor,
163
+ resnet_time_scale_shift=resnet_time_scale_shift,
164
+ cross_attention_dim=cross_attention_dim,
165
+ attn_num_head_channels=attention_head_dim[-1],
166
+ resnet_groups=norm_num_groups,
167
+ dual_cross_attention=dual_cross_attention,
168
+ use_linear_projection=use_linear_projection,
169
+ upcast_attention=upcast_attention,
170
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
171
+ unet_use_temporal_attention=unet_use_temporal_attention,
172
+ use_inflated_groupnorm=use_inflated_groupnorm,
173
+ use_motion_module=use_motion_module and motion_module_mid_block,
174
+ motion_module_type=motion_module_type,
175
+ motion_module_kwargs=motion_module_kwargs,
176
+ )
177
+ else:
178
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
179
+
180
+ # count how many layers upsample the videos
181
+ self.num_upsamplers = 0
182
+
183
+ # up
184
+ reversed_block_out_channels = list(reversed(block_out_channels))
185
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
186
+ only_cross_attention = list(reversed(only_cross_attention))
187
+ output_channel = reversed_block_out_channels[0]
188
+ for i, up_block_type in enumerate(up_block_types):
189
+ res = 2 ** (3 - i)
190
+ is_final_block = i == len(block_out_channels) - 1
191
+
192
+ prev_output_channel = output_channel
193
+ output_channel = reversed_block_out_channels[i]
194
+ input_channel = reversed_block_out_channels[
195
+ min(i + 1, len(block_out_channels) - 1)
196
+ ]
197
+
198
+ # add upsample block for all BUT final layer
199
+ if not is_final_block:
200
+ add_upsample = True
201
+ self.num_upsamplers += 1
202
+ else:
203
+ add_upsample = False
204
+
205
+ up_block = get_up_block(
206
+ up_block_type,
207
+ num_layers=layers_per_block + 1,
208
+ in_channels=input_channel,
209
+ out_channels=output_channel,
210
+ prev_output_channel=prev_output_channel,
211
+ temb_channels=time_embed_dim,
212
+ add_upsample=add_upsample,
213
+ resnet_eps=norm_eps,
214
+ resnet_act_fn=act_fn,
215
+ resnet_groups=norm_num_groups,
216
+ cross_attention_dim=cross_attention_dim,
217
+ attn_num_head_channels=reversed_attention_head_dim[i],
218
+ dual_cross_attention=dual_cross_attention,
219
+ use_linear_projection=use_linear_projection,
220
+ only_cross_attention=only_cross_attention[i],
221
+ upcast_attention=upcast_attention,
222
+ resnet_time_scale_shift=resnet_time_scale_shift,
223
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
224
+ unet_use_temporal_attention=unet_use_temporal_attention,
225
+ use_inflated_groupnorm=use_inflated_groupnorm,
226
+ use_motion_module=use_motion_module
227
+ and (res in motion_module_resolutions),
228
+ motion_module_type=motion_module_type,
229
+ motion_module_kwargs=motion_module_kwargs,
230
+ )
231
+ self.up_blocks.append(up_block)
232
+ prev_output_channel = output_channel
233
+
234
+ # out
235
+ if use_inflated_groupnorm:
236
+ self.conv_norm_out = InflatedGroupNorm(
237
+ num_channels=block_out_channels[0],
238
+ num_groups=norm_num_groups,
239
+ eps=norm_eps,
240
+ )
241
+ else:
242
+ self.conv_norm_out = nn.GroupNorm(
243
+ num_channels=block_out_channels[0],
244
+ num_groups=norm_num_groups,
245
+ eps=norm_eps,
246
+ )
247
+ self.conv_act = nn.SiLU()
248
+ self.conv_out = InflatedConv3d(
249
+ block_out_channels[0], out_channels, kernel_size=3, padding=1
250
+ )
251
+
252
+ @property
253
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
254
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
255
+ r"""
256
+ Returns:
257
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
258
+ indexed by their weight names.
259
+ """
260
+ # set recursively
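+ # (modules nested under "temporal_transformer", i.e. the motion module, are skipped by the recursion below)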
261
+ processors = {}
262
+
263
+ def fn_recursive_add_processors(
264
+ name: str,
265
+ module: torch.nn.Module,
266
+ processors: Dict[str, AttentionProcessor],
267
+ ):
268
+ if hasattr(module, "set_processor"):
269
+ processors[f"{name}.processor"] = module.processor
270
+
271
+ for sub_name, child in module.named_children():
272
+ if "temporal_transformer" not in sub_name:
273
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
274
+
275
+ return processors
276
+
277
+ for name, module in self.named_children():
278
+ if "temporal_transformer" not in name:
279
+ fn_recursive_add_processors(name, module, processors)
280
+
281
+ return processors
282
+
283
+ def set_attention_slice(self, slice_size):
284
+ r"""
285
+ Enable sliced attention computation.
286
+
287
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
288
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
289
+
290
+ Args:
291
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
292
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
293
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
294
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
295
+ must be a multiple of `slice_size`.
296
+ """
297
+ sliceable_head_dims = []
298
+
299
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
300
+ if hasattr(module, "set_attention_slice"):
301
+ sliceable_head_dims.append(module.sliceable_head_dim)
302
+
303
+ for child in module.children():
304
+ fn_recursive_retrieve_slicable_dims(child)
305
+
306
+ # retrieve number of attention layers
307
+ for module in self.children():
308
+ fn_recursive_retrieve_slicable_dims(module)
309
+
310
+ num_slicable_layers = len(sliceable_head_dims)
311
+
312
+ if slice_size == "auto":
313
+ # half the attention head size is usually a good trade-off between
314
+ # speed and memory
315
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
316
+ elif slice_size == "max":
317
+ # make smallest slice possible
318
+ slice_size = num_slicable_layers * [1]
319
+
320
+ slice_size = (
321
+ num_slicable_layers * [slice_size]
322
+ if not isinstance(slice_size, list)
323
+ else slice_size
324
+ )
325
+
326
+ if len(slice_size) != len(sliceable_head_dims):
327
+ raise ValueError(
328
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
329
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
330
+ )
331
+
332
+ for i in range(len(slice_size)):
333
+ size = slice_size[i]
334
+ dim = sliceable_head_dims[i]
335
+ if size is not None and size > dim:
336
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
337
+
338
+ # Recursively walk through all the children.
339
+ # Any children which exposes the set_attention_slice method
340
+ # gets the message
341
+ def fn_recursive_set_attention_slice(
342
+ module: torch.nn.Module, slice_size: List[int]
343
+ ):
344
+ if hasattr(module, "set_attention_slice"):
345
+ module.set_attention_slice(slice_size.pop())
346
+
347
+ for child in module.children():
348
+ fn_recursive_set_attention_slice(child, slice_size)
349
+
350
+ reversed_slice_size = list(reversed(slice_size))
351
+ for module in self.children():
352
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
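+ # e.g. unet.set_attention_slice("auto") computes attention in two slices per layer, while "max" runs
+ # one attention head at a time for the largest memory savings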
353
+
354
+ def _set_gradient_checkpointing(self, module, value=False):
355
+ if hasattr(module, "gradient_checkpointing"):
356
+ module.gradient_checkpointing = value
357
+
358
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
359
+ def set_attn_processor(
360
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
361
+ ):
362
+ r"""
363
+ Sets the attention processor to use to compute attention.
364
+
365
+ Parameters:
366
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
367
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
368
+ for **all** `Attention` layers.
369
+
370
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
371
+ processor. This is strongly recommended when setting trainable attention processors.
372
+
373
+ """
374
+ count = len(self.attn_processors.keys())
375
+
376
+ if isinstance(processor, dict) and len(processor) != count:
377
+ raise ValueError(
378
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
379
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
380
+ )
381
+
382
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
383
+ if hasattr(module, "set_processor"):
384
+ if not isinstance(processor, dict):
385
+ module.set_processor(processor)
386
+ else:
387
+ module.set_processor(processor.pop(f"{name}.processor"))
388
+
389
+ for sub_name, child in module.named_children():
390
+ if "temporal_transformer" not in sub_name:
391
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
392
+
393
+ for name, module in self.named_children():
394
+ if "temporal_transformer" not in name:
395
+ fn_recursive_attn_processor(name, module, processor)
396
+
397
+ def forward(
398
+ self,
399
+ sample: torch.FloatTensor,
400
+ timestep: Union[torch.Tensor, float, int],
401
+ encoder_hidden_states: torch.Tensor,
402
+ class_labels: Optional[torch.Tensor] = None,
403
+ pose_cond_fea: Optional[torch.Tensor] = None,
404
+ attention_mask: Optional[torch.Tensor] = None,
405
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
406
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
407
+ return_dict: bool = True,
408
+ ) -> Union[UNet3DConditionOutput, Tuple]:
409
+ r"""
410
+ Args:
411
+ sample (`torch.FloatTensor`): (batch, channel, num_frames, height, width) noisy inputs tensor
412
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
413
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
414
+ return_dict (`bool`, *optional*, defaults to `True`):
415
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
416
+
417
+ Returns:
418
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
419
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
420
+ returning a tuple, the first element is the sample tensor.
421
+ """
422
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
423
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
424
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
425
+ # on the fly if necessary.
426
+ default_overall_up_factor = 2**self.num_upsamplers
427
+
428
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
429
+ forward_upsample_size = False
430
+ upsample_size = None
431
+
432
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
433
+ logger.info("Forward upsample size to force interpolation output size.")
434
+ forward_upsample_size = True
435
+
436
+ # prepare attention_mask
437
+ if attention_mask is not None:
438
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
439
+ attention_mask = attention_mask.unsqueeze(1)
440
+
441
+ # center input if necessary
442
+ if self.config.center_input_sample:
443
+ sample = 2 * sample - 1.0
444
+
445
+ # time
446
+ timesteps = timestep
447
+ if not torch.is_tensor(timesteps):
448
+ # This would be a good case for the `match` statement (Python 3.10+)
449
+ is_mps = sample.device.type == "mps"
450
+ if isinstance(timestep, float):
451
+ dtype = torch.float32 if is_mps else torch.float64
452
+ else:
453
+ dtype = torch.int32 if is_mps else torch.int64
454
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
455
+ elif len(timesteps.shape) == 0:
456
+ timesteps = timesteps[None].to(sample.device)
457
+
458
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
459
+ timesteps = timesteps.expand(sample.shape[0])
460
+
461
+ t_emb = self.time_proj(timesteps)
462
+
463
+ # timesteps does not contain any weights and will always return f32 tensors
464
+ # but time_embedding might actually be running in fp16. so we need to cast here.
465
+ # there might be better ways to encapsulate this.
466
+ t_emb = t_emb.to(dtype=self.dtype)
467
+ emb = self.time_embedding(t_emb)
468
+
469
+ if self.class_embedding is not None:
470
+ if class_labels is None:
471
+ raise ValueError(
472
+ "class_labels should be provided when num_class_embeds > 0"
473
+ )
474
+
475
+ if self.config.class_embed_type == "timestep":
476
+ class_labels = self.time_proj(class_labels)
477
+
478
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
479
+ emb = emb + class_emb
480
+
481
+ # pre-process
482
+ sample = self.conv_in(sample)
483
+ if pose_cond_fea is not None:
484
+ sample = sample + pose_cond_fea
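+ # the pose condition features (e.g. the PoseGuider output, already shaped like the conv_in activations)
+ # are fused by simple element-wise addition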
485
+
486
+ # down
487
+ down_block_res_samples = (sample,)
488
+ for downsample_block in self.down_blocks:
489
+ if (
490
+ hasattr(downsample_block, "has_cross_attention")
491
+ and downsample_block.has_cross_attention
492
+ ):
493
+ sample, res_samples = downsample_block(
494
+ hidden_states=sample,
495
+ temb=emb,
496
+ encoder_hidden_states=encoder_hidden_states,
497
+ attention_mask=attention_mask,
498
+ )
499
+ else:
500
+ sample, res_samples = downsample_block(
501
+ hidden_states=sample,
502
+ temb=emb,
503
+ encoder_hidden_states=encoder_hidden_states,
504
+ )
505
+
506
+ down_block_res_samples += res_samples
507
+
508
+ if down_block_additional_residuals is not None:
509
+ new_down_block_res_samples = ()
510
+
511
+ for down_block_res_sample, down_block_additional_residual in zip(
512
+ down_block_res_samples, down_block_additional_residuals
513
+ ):
514
+ down_block_res_sample = (
515
+ down_block_res_sample + down_block_additional_residual
516
+ )
517
+ new_down_block_res_samples += (down_block_res_sample,)
518
+
519
+ down_block_res_samples = new_down_block_res_samples
520
+
521
+ # mid
522
+ sample = self.mid_block(
523
+ sample,
524
+ emb,
525
+ encoder_hidden_states=encoder_hidden_states,
526
+ attention_mask=attention_mask,
527
+ )
528
+
529
+ if mid_block_additional_residual is not None:
530
+ sample = sample + mid_block_additional_residual
531
+
532
+ # up
533
+ for i, upsample_block in enumerate(self.up_blocks):
534
+ is_final_block = i == len(self.up_blocks) - 1
535
+
536
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
537
+ down_block_res_samples = down_block_res_samples[
538
+ : -len(upsample_block.resnets)
539
+ ]
540
+
541
+ # if we have not reached the final block and need to forward the
542
+ # upsample size, we do it here
543
+ if not is_final_block and forward_upsample_size:
544
+ upsample_size = down_block_res_samples[-1].shape[2:]
545
+
546
+ if (
547
+ hasattr(upsample_block, "has_cross_attention")
548
+ and upsample_block.has_cross_attention
549
+ ):
550
+ sample = upsample_block(
551
+ hidden_states=sample,
552
+ temb=emb,
553
+ res_hidden_states_tuple=res_samples,
554
+ encoder_hidden_states=encoder_hidden_states,
555
+ upsample_size=upsample_size,
556
+ attention_mask=attention_mask,
557
+ )
558
+ else:
559
+ sample = upsample_block(
560
+ hidden_states=sample,
561
+ temb=emb,
562
+ res_hidden_states_tuple=res_samples,
563
+ upsample_size=upsample_size,
564
+ encoder_hidden_states=encoder_hidden_states,
565
+ )
566
+
567
+ # post-process
568
+ sample = self.conv_norm_out(sample)
569
+ sample = self.conv_act(sample)
570
+ sample = self.conv_out(sample)
571
+
572
+ if not return_dict:
573
+ return (sample,)
574
+
575
+ return UNet3DConditionOutput(sample=sample)
576
+
577
+ @classmethod
578
+ def from_pretrained_2d(
579
+ cls,
580
+ pretrained_model_path: PathLike,
581
+ motion_module_path: PathLike,
582
+ subfolder=None,
583
+ unet_additional_kwargs=None,
584
+ mm_zero_proj_out=False,
585
+ ):
586
+ pretrained_model_path = Path(pretrained_model_path)
587
+ motion_module_path = Path(motion_module_path)
588
+ if subfolder is not None:
589
+ pretrained_model_path = pretrained_model_path.joinpath(subfolder)
590
+ logger.info(
591
+ f"loaded temporal unet's pretrained weights from {pretrained_model_path} ..."
592
+ )
593
+
594
+ config_file = pretrained_model_path / "config.json"
595
+ if not (config_file.exists() and config_file.is_file()):
596
+ raise RuntimeError(f"{config_file} does not exist or is not a file")
597
+
598
+ unet_config = cls.load_config(config_file)
599
+ unet_config["_class_name"] = cls.__name__
600
+ unet_config["down_block_types"] = [
601
+ "CrossAttnDownBlock3D",
602
+ "CrossAttnDownBlock3D",
603
+ "CrossAttnDownBlock3D",
604
+ "DownBlock3D",
605
+ ]
606
+ unet_config["up_block_types"] = [
607
+ "UpBlock3D",
608
+ "CrossAttnUpBlock3D",
609
+ "CrossAttnUpBlock3D",
610
+ "CrossAttnUpBlock3D",
611
+ ]
612
+ unet_config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
613
+
614
+ model = cls.from_config(unet_config, **unet_additional_kwargs)
615
+ # load the vanilla weights
616
+ if pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME).exists():
617
+ logger.debug(
618
+ f"loading safetensors weights from {pretrained_model_path} ..."
619
+ )
620
+ state_dict = load_file(
621
+ pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME), device="cpu"
622
+ )
623
+
624
+ elif pretrained_model_path.joinpath(WEIGHTS_NAME).exists():
625
+ logger.debug(f"loading weights from {pretrained_model_path} ...")
626
+ state_dict = torch.load(
627
+ pretrained_model_path.joinpath(WEIGHTS_NAME),
628
+ map_location="cpu",
629
+ weights_only=True,
630
+ )
631
+ else:
632
+ raise FileNotFoundError(f"no weights file found in {pretrained_model_path}")
633
+
634
+ # load the motion module weights
635
+ if motion_module_path.exists() and motion_module_path.is_file():
636
+ if motion_module_path.suffix.lower() in [".pth", ".pt", ".ckpt"]:
637
+ logger.info(f"Load motion module params from {motion_module_path}")
638
+ motion_state_dict = torch.load(
639
+ motion_module_path, map_location="cpu", weights_only=True
640
+ )
641
+ elif motion_module_path.suffix.lower() == ".safetensors":
642
+ motion_state_dict = load_file(motion_module_path, device="cpu")
643
+ else:
644
+ raise RuntimeError(
645
+ f"unknown file format for motion module weights: {motion_module_path.suffix}"
646
+ )
647
+ if mm_zero_proj_out:
648
+ logger.info("Zero-initializing proj_out layers in the motion module ...")
649
+ new_motion_state_dict = OrderedDict()
650
+ for k in motion_state_dict:
651
+ if "proj_out" in k:
652
+ continue
653
+ new_motion_state_dict[k] = motion_state_dict[k]
654
+ motion_state_dict = new_motion_state_dict
655
+
656
+
657
+
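+ # NOTE: keys ending in "pe" are the motion module's positional-encoding buffers; they are dropped below,
+ # presumably so a different temporal length in the checkpoint does not break load_state_dict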
658
+ for weight_name in list(motion_state_dict.keys()):
659
+ if weight_name[-2:] == 'pe':
660
+ del motion_state_dict[weight_name]
661
+ # print(weight_name)
662
+
663
+ # merge the state dicts
664
+ state_dict.update(motion_state_dict)
665
+
666
+ # load the weights into the model
667
+ m, u = model.load_state_dict(state_dict, strict=False)
668
+ logger.debug(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
669
+
670
+ params = [
671
+ p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()
672
+ ]
673
+ logger.info(f"Loaded {sum(params) / 1e6}M-parameter motion module")
674
+
675
+ return model
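+
+ # Example usage (a hypothetical sketch; the paths and unet_additional_kwargs are illustrative and should
+ # mirror the values in your own config):
+ # unet = UNet3DConditionModel.from_pretrained_2d(
+ #     "./pretrained_weights/sd-image-variations-diffusers",
+ #     "./pretrained_weights/motion_module.pth",
+ #     subfolder="unet",
+ #     unet_additional_kwargs={"use_motion_module": True, "motion_module_type": "Vanilla", "motion_module_kwargs": {}},
+ # )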
models/unet_3d_blocks.py ADDED
@@ -0,0 +1,871 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+
3
+ import pdb
4
+
5
+ import torch
6
+ from torch import nn
7
+
8
+ from .motion_module import get_motion_module
9
+
10
+ # from .motion_module import get_motion_module
11
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
12
+ from .transformer_3d import Transformer3DModel
13
+
14
+
15
+ def get_down_block(
16
+ down_block_type,
17
+ num_layers,
18
+ in_channels,
19
+ out_channels,
20
+ temb_channels,
21
+ add_downsample,
22
+ resnet_eps,
23
+ resnet_act_fn,
24
+ attn_num_head_channels,
25
+ resnet_groups=None,
26
+ cross_attention_dim=None,
27
+ downsample_padding=None,
28
+ dual_cross_attention=False,
29
+ use_linear_projection=False,
30
+ only_cross_attention=False,
31
+ upcast_attention=False,
32
+ resnet_time_scale_shift="default",
33
+ unet_use_cross_frame_attention=None,
34
+ unet_use_temporal_attention=None,
35
+ use_inflated_groupnorm=None,
36
+ use_motion_module=None,
37
+ motion_module_type=None,
38
+ motion_module_kwargs=None,
39
+ ):
40
+ down_block_type = (
41
+ down_block_type[7:]
42
+ if down_block_type.startswith("UNetRes")
43
+ else down_block_type
44
+ )
45
+ if down_block_type == "DownBlock3D":
46
+ return DownBlock3D(
47
+ num_layers=num_layers,
48
+ in_channels=in_channels,
49
+ out_channels=out_channels,
50
+ temb_channels=temb_channels,
51
+ add_downsample=add_downsample,
52
+ resnet_eps=resnet_eps,
53
+ resnet_act_fn=resnet_act_fn,
54
+ resnet_groups=resnet_groups,
55
+ downsample_padding=downsample_padding,
56
+ resnet_time_scale_shift=resnet_time_scale_shift,
57
+ use_inflated_groupnorm=use_inflated_groupnorm,
58
+ use_motion_module=use_motion_module,
59
+ motion_module_type=motion_module_type,
60
+ motion_module_kwargs=motion_module_kwargs,
61
+ )
62
+ elif down_block_type == "CrossAttnDownBlock3D":
63
+ if cross_attention_dim is None:
64
+ raise ValueError(
65
+ "cross_attention_dim must be specified for CrossAttnDownBlock3D"
66
+ )
67
+ return CrossAttnDownBlock3D(
68
+ num_layers=num_layers,
69
+ in_channels=in_channels,
70
+ out_channels=out_channels,
71
+ temb_channels=temb_channels,
72
+ add_downsample=add_downsample,
73
+ resnet_eps=resnet_eps,
74
+ resnet_act_fn=resnet_act_fn,
75
+ resnet_groups=resnet_groups,
76
+ downsample_padding=downsample_padding,
77
+ cross_attention_dim=cross_attention_dim,
78
+ attn_num_head_channels=attn_num_head_channels,
79
+ dual_cross_attention=dual_cross_attention,
80
+ use_linear_projection=use_linear_projection,
81
+ only_cross_attention=only_cross_attention,
82
+ upcast_attention=upcast_attention,
83
+ resnet_time_scale_shift=resnet_time_scale_shift,
84
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
85
+ unet_use_temporal_attention=unet_use_temporal_attention,
86
+ use_inflated_groupnorm=use_inflated_groupnorm,
87
+ use_motion_module=use_motion_module,
88
+ motion_module_type=motion_module_type,
89
+ motion_module_kwargs=motion_module_kwargs,
90
+ )
91
+ raise ValueError(f"{down_block_type} does not exist.")
92
+
93
+
94
+ def get_up_block(
95
+ up_block_type,
96
+ num_layers,
97
+ in_channels,
98
+ out_channels,
99
+ prev_output_channel,
100
+ temb_channels,
101
+ add_upsample,
102
+ resnet_eps,
103
+ resnet_act_fn,
104
+ attn_num_head_channels,
105
+ resnet_groups=None,
106
+ cross_attention_dim=None,
107
+ dual_cross_attention=False,
108
+ use_linear_projection=False,
109
+ only_cross_attention=False,
110
+ upcast_attention=False,
111
+ resnet_time_scale_shift="default",
112
+ unet_use_cross_frame_attention=None,
113
+ unet_use_temporal_attention=None,
114
+ use_inflated_groupnorm=None,
115
+ use_motion_module=None,
116
+ motion_module_type=None,
117
+ motion_module_kwargs=None,
118
+ ):
119
+ up_block_type = (
120
+ up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
121
+ )
122
+ if up_block_type == "UpBlock3D":
123
+ return UpBlock3D(
124
+ num_layers=num_layers,
125
+ in_channels=in_channels,
126
+ out_channels=out_channels,
127
+ prev_output_channel=prev_output_channel,
128
+ temb_channels=temb_channels,
129
+ add_upsample=add_upsample,
130
+ resnet_eps=resnet_eps,
131
+ resnet_act_fn=resnet_act_fn,
132
+ resnet_groups=resnet_groups,
133
+ resnet_time_scale_shift=resnet_time_scale_shift,
134
+ use_inflated_groupnorm=use_inflated_groupnorm,
135
+ use_motion_module=use_motion_module,
136
+ motion_module_type=motion_module_type,
137
+ motion_module_kwargs=motion_module_kwargs,
138
+ )
139
+ elif up_block_type == "CrossAttnUpBlock3D":
140
+ if cross_attention_dim is None:
141
+ raise ValueError(
142
+ "cross_attention_dim must be specified for CrossAttnUpBlock3D"
143
+ )
144
+ return CrossAttnUpBlock3D(
145
+ num_layers=num_layers,
146
+ in_channels=in_channels,
147
+ out_channels=out_channels,
148
+ prev_output_channel=prev_output_channel,
149
+ temb_channels=temb_channels,
150
+ add_upsample=add_upsample,
151
+ resnet_eps=resnet_eps,
152
+ resnet_act_fn=resnet_act_fn,
153
+ resnet_groups=resnet_groups,
154
+ cross_attention_dim=cross_attention_dim,
155
+ attn_num_head_channels=attn_num_head_channels,
156
+ dual_cross_attention=dual_cross_attention,
157
+ use_linear_projection=use_linear_projection,
158
+ only_cross_attention=only_cross_attention,
159
+ upcast_attention=upcast_attention,
160
+ resnet_time_scale_shift=resnet_time_scale_shift,
161
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
162
+ unet_use_temporal_attention=unet_use_temporal_attention,
163
+ use_inflated_groupnorm=use_inflated_groupnorm,
164
+ use_motion_module=use_motion_module,
165
+ motion_module_type=motion_module_type,
166
+ motion_module_kwargs=motion_module_kwargs,
167
+ )
168
+ raise ValueError(f"{up_block_type} does not exist.")
169
+
170
+
171
+ class UNetMidBlock3DCrossAttn(nn.Module):
172
+ def __init__(
173
+ self,
174
+ in_channels: int,
175
+ temb_channels: int,
176
+ dropout: float = 0.0,
177
+ num_layers: int = 1,
178
+ resnet_eps: float = 1e-6,
179
+ resnet_time_scale_shift: str = "default",
180
+ resnet_act_fn: str = "swish",
181
+ resnet_groups: int = 32,
182
+ resnet_pre_norm: bool = True,
183
+ attn_num_head_channels=1,
184
+ output_scale_factor=1.0,
185
+ cross_attention_dim=1280,
186
+ dual_cross_attention=False,
187
+ use_linear_projection=False,
188
+ upcast_attention=False,
189
+ unet_use_cross_frame_attention=None,
190
+ unet_use_temporal_attention=None,
191
+ use_inflated_groupnorm=None,
192
+ use_motion_module=None,
193
+ motion_module_type=None,
194
+ motion_module_kwargs=None,
195
+ ):
196
+ super().__init__()
197
+
198
+ self.has_cross_attention = True
199
+ self.attn_num_head_channels = attn_num_head_channels
200
+ resnet_groups = (
201
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
202
+ )
203
+
204
+ # there is always at least one resnet
205
+ resnets = [
206
+ ResnetBlock3D(
207
+ in_channels=in_channels,
208
+ out_channels=in_channels,
209
+ temb_channels=temb_channels,
210
+ eps=resnet_eps,
211
+ groups=resnet_groups,
212
+ dropout=dropout,
213
+ time_embedding_norm=resnet_time_scale_shift,
214
+ non_linearity=resnet_act_fn,
215
+ output_scale_factor=output_scale_factor,
216
+ pre_norm=resnet_pre_norm,
217
+ use_inflated_groupnorm=use_inflated_groupnorm,
218
+ )
219
+ ]
220
+ attentions = []
221
+ motion_modules = []
222
+
223
+ for _ in range(num_layers):
224
+ if dual_cross_attention:
225
+ raise NotImplementedError
226
+ attentions.append(
227
+ Transformer3DModel(
228
+ attn_num_head_channels,
229
+ in_channels // attn_num_head_channels,
230
+ in_channels=in_channels,
231
+ num_layers=1,
232
+ cross_attention_dim=cross_attention_dim,
233
+ norm_num_groups=resnet_groups,
234
+ use_linear_projection=use_linear_projection,
235
+ upcast_attention=upcast_attention,
236
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
237
+ unet_use_temporal_attention=unet_use_temporal_attention,
238
+ )
239
+ )
240
+ motion_modules.append(
241
+ get_motion_module(
242
+ in_channels=in_channels,
243
+ motion_module_type=motion_module_type,
244
+ motion_module_kwargs=motion_module_kwargs,
245
+ )
246
+ if use_motion_module
247
+ else None
248
+ )
249
+ resnets.append(
250
+ ResnetBlock3D(
251
+ in_channels=in_channels,
252
+ out_channels=in_channels,
253
+ temb_channels=temb_channels,
254
+ eps=resnet_eps,
255
+ groups=resnet_groups,
256
+ dropout=dropout,
257
+ time_embedding_norm=resnet_time_scale_shift,
258
+ non_linearity=resnet_act_fn,
259
+ output_scale_factor=output_scale_factor,
260
+ pre_norm=resnet_pre_norm,
261
+ use_inflated_groupnorm=use_inflated_groupnorm,
262
+ )
263
+ )
264
+
265
+ self.attentions = nn.ModuleList(attentions)
266
+ self.resnets = nn.ModuleList(resnets)
267
+ self.motion_modules = nn.ModuleList(motion_modules)
268
+
269
+ def forward(
270
+ self,
271
+ hidden_states,
272
+ temb=None,
273
+ encoder_hidden_states=None,
274
+ attention_mask=None,
275
+ ):
276
+ hidden_states = self.resnets[0](hidden_states, temb)
277
+ for attn, resnet, motion_module in zip(
278
+ self.attentions, self.resnets[1:], self.motion_modules
279
+ ):
280
+ hidden_states = attn(
281
+ hidden_states,
282
+ encoder_hidden_states=encoder_hidden_states,
283
+ ).sample
284
+ hidden_states = (
285
+ motion_module(
286
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
287
+ )
288
+ if motion_module is not None
289
+ else hidden_states
290
+ )
291
+ hidden_states = resnet(hidden_states, temb)
292
+
293
+ return hidden_states
294
+
295
+
296
+ class CrossAttnDownBlock3D(nn.Module):
297
+ def __init__(
298
+ self,
299
+ in_channels: int,
300
+ out_channels: int,
301
+ temb_channels: int,
302
+ dropout: float = 0.0,
303
+ num_layers: int = 1,
304
+ resnet_eps: float = 1e-6,
305
+ resnet_time_scale_shift: str = "default",
306
+ resnet_act_fn: str = "swish",
307
+ resnet_groups: int = 32,
308
+ resnet_pre_norm: bool = True,
309
+ attn_num_head_channels=1,
310
+ cross_attention_dim=1280,
311
+ output_scale_factor=1.0,
312
+ downsample_padding=1,
313
+ add_downsample=True,
314
+ dual_cross_attention=False,
315
+ use_linear_projection=False,
316
+ only_cross_attention=False,
317
+ upcast_attention=False,
318
+ unet_use_cross_frame_attention=None,
319
+ unet_use_temporal_attention=None,
320
+ use_inflated_groupnorm=None,
321
+ use_motion_module=None,
322
+ motion_module_type=None,
323
+ motion_module_kwargs=None,
324
+ ):
325
+ super().__init__()
326
+ resnets = []
327
+ attentions = []
328
+ motion_modules = []
329
+
330
+ self.has_cross_attention = True
331
+ self.attn_num_head_channels = attn_num_head_channels
332
+
333
+ for i in range(num_layers):
334
+ in_channels = in_channels if i == 0 else out_channels
335
+ resnets.append(
336
+ ResnetBlock3D(
337
+ in_channels=in_channels,
338
+ out_channels=out_channels,
339
+ temb_channels=temb_channels,
340
+ eps=resnet_eps,
341
+ groups=resnet_groups,
342
+ dropout=dropout,
343
+ time_embedding_norm=resnet_time_scale_shift,
344
+ non_linearity=resnet_act_fn,
345
+ output_scale_factor=output_scale_factor,
346
+ pre_norm=resnet_pre_norm,
347
+ use_inflated_groupnorm=use_inflated_groupnorm,
348
+ )
349
+ )
350
+ if dual_cross_attention:
351
+ raise NotImplementedError
352
+ attentions.append(
353
+ Transformer3DModel(
354
+ attn_num_head_channels,
355
+ out_channels // attn_num_head_channels,
356
+ in_channels=out_channels,
357
+ num_layers=1,
358
+ cross_attention_dim=cross_attention_dim,
359
+ norm_num_groups=resnet_groups,
360
+ use_linear_projection=use_linear_projection,
361
+ only_cross_attention=only_cross_attention,
362
+ upcast_attention=upcast_attention,
363
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
364
+ unet_use_temporal_attention=unet_use_temporal_attention,
365
+ )
366
+ )
367
+ motion_modules.append(
368
+ get_motion_module(
369
+ in_channels=out_channels,
370
+ motion_module_type=motion_module_type,
371
+ motion_module_kwargs=motion_module_kwargs,
372
+ )
373
+ if use_motion_module
374
+ else None
375
+ )
376
+
377
+ self.attentions = nn.ModuleList(attentions)
378
+ self.resnets = nn.ModuleList(resnets)
379
+ self.motion_modules = nn.ModuleList(motion_modules)
380
+
381
+ if add_downsample:
382
+ self.downsamplers = nn.ModuleList(
383
+ [
384
+ Downsample3D(
385
+ out_channels,
386
+ use_conv=True,
387
+ out_channels=out_channels,
388
+ padding=downsample_padding,
389
+ name="op",
390
+ )
391
+ ]
392
+ )
393
+ else:
394
+ self.downsamplers = None
395
+
396
+ self.gradient_checkpointing = False
397
+
398
+ def forward(
399
+ self,
400
+ hidden_states,
401
+ temb=None,
402
+ encoder_hidden_states=None,
403
+ attention_mask=None,
404
+ ):
405
+ output_states = ()
406
+
407
+ for i, (resnet, attn, motion_module) in enumerate(
408
+ zip(self.resnets, self.attentions, self.motion_modules)
409
+ ):
410
+ # self.gradient_checkpointing = False
411
+ if self.training and self.gradient_checkpointing:
412
+
413
+ def create_custom_forward(module, return_dict=None):
414
+ def custom_forward(*inputs):
415
+ if return_dict is not None:
416
+ return module(*inputs, return_dict=return_dict)
417
+ else:
418
+ return module(*inputs)
419
+
420
+ return custom_forward
421
+
422
+ hidden_states = torch.utils.checkpoint.checkpoint(
423
+ create_custom_forward(resnet), hidden_states, temb
424
+ )
425
+ hidden_states = torch.utils.checkpoint.checkpoint(
426
+ create_custom_forward(attn, return_dict=False),
427
+ hidden_states,
428
+ encoder_hidden_states,
429
+ )[0]
430
+
431
+ # add motion module
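+ # (hidden_states.requires_grad_() below gives torch.utils.checkpoint a grad-requiring input so the
+ # checkpointed motion-module segment is actually recomputed during backward)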
432
+ if motion_module is not None:
433
+ hidden_states = torch.utils.checkpoint.checkpoint(
434
+ create_custom_forward(motion_module),
435
+ hidden_states.requires_grad_(),
436
+ temb,
437
+ encoder_hidden_states,
438
+ )
439
+
440
+ # # add motion module
441
+ # hidden_states = (
442
+ # motion_module(
443
+ # hidden_states, temb, encoder_hidden_states=encoder_hidden_states
444
+ # )
445
+ # if motion_module is not None
446
+ # else hidden_states
447
+ # )
448
+
449
+ else:
450
+ hidden_states = resnet(hidden_states, temb)
451
+ hidden_states = attn(
452
+ hidden_states,
453
+ encoder_hidden_states=encoder_hidden_states,
454
+ ).sample
455
+
456
+ # add motion module
457
+ hidden_states = (
458
+ motion_module(
459
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
460
+ )
461
+ if motion_module is not None
462
+ else hidden_states
463
+ )
464
+
465
+ output_states += (hidden_states,)
466
+
467
+ if self.downsamplers is not None:
468
+ for downsampler in self.downsamplers:
469
+ hidden_states = downsampler(hidden_states)
470
+
471
+ output_states += (hidden_states,)
472
+
473
+ return hidden_states, output_states
474
+
475
+
476
+ class DownBlock3D(nn.Module):
477
+ def __init__(
478
+ self,
479
+ in_channels: int,
480
+ out_channels: int,
481
+ temb_channels: int,
482
+ dropout: float = 0.0,
483
+ num_layers: int = 1,
484
+ resnet_eps: float = 1e-6,
485
+ resnet_time_scale_shift: str = "default",
486
+ resnet_act_fn: str = "swish",
487
+ resnet_groups: int = 32,
488
+ resnet_pre_norm: bool = True,
489
+ output_scale_factor=1.0,
490
+ add_downsample=True,
491
+ downsample_padding=1,
492
+ use_inflated_groupnorm=None,
493
+ use_motion_module=None,
494
+ motion_module_type=None,
495
+ motion_module_kwargs=None,
496
+ ):
497
+ super().__init__()
498
+ resnets = []
499
+ motion_modules = []
500
+
501
+ # use_motion_module = False
502
+ for i in range(num_layers):
503
+ in_channels = in_channels if i == 0 else out_channels
504
+ resnets.append(
505
+ ResnetBlock3D(
506
+ in_channels=in_channels,
507
+ out_channels=out_channels,
508
+ temb_channels=temb_channels,
509
+ eps=resnet_eps,
510
+ groups=resnet_groups,
511
+ dropout=dropout,
512
+ time_embedding_norm=resnet_time_scale_shift,
513
+ non_linearity=resnet_act_fn,
514
+ output_scale_factor=output_scale_factor,
515
+ pre_norm=resnet_pre_norm,
516
+ use_inflated_groupnorm=use_inflated_groupnorm,
517
+ )
518
+ )
519
+ motion_modules.append(
520
+ get_motion_module(
521
+ in_channels=out_channels,
522
+ motion_module_type=motion_module_type,
523
+ motion_module_kwargs=motion_module_kwargs,
524
+ )
525
+ if use_motion_module
526
+ else None
527
+ )
528
+
529
+ self.resnets = nn.ModuleList(resnets)
530
+ self.motion_modules = nn.ModuleList(motion_modules)
531
+
532
+ if add_downsample:
533
+ self.downsamplers = nn.ModuleList(
534
+ [
535
+ Downsample3D(
536
+ out_channels,
537
+ use_conv=True,
538
+ out_channels=out_channels,
539
+ padding=downsample_padding,
540
+ name="op",
541
+ )
542
+ ]
543
+ )
544
+ else:
545
+ self.downsamplers = None
546
+
547
+ self.gradient_checkpointing = False
548
+
549
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
550
+ output_states = ()
551
+
552
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
553
+ # print(f"DownBlock3D {self.gradient_checkpointing = }")
554
+ if self.training and self.gradient_checkpointing:
555
+
556
+ def create_custom_forward(module):
557
+ def custom_forward(*inputs):
558
+ return module(*inputs)
559
+
560
+ return custom_forward
561
+
562
+ hidden_states = torch.utils.checkpoint.checkpoint(
563
+ create_custom_forward(resnet), hidden_states, temb
564
+ )
565
+ if motion_module is not None:
566
+ hidden_states = torch.utils.checkpoint.checkpoint(
567
+ create_custom_forward(motion_module),
568
+ hidden_states.requires_grad_(),
569
+ temb,
570
+ encoder_hidden_states,
571
+ )
572
+ else:
573
+ hidden_states = resnet(hidden_states, temb)
574
+
575
+ # add motion module
576
+ hidden_states = (
577
+ motion_module(
578
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
579
+ )
580
+ if motion_module is not None
581
+ else hidden_states
582
+ )
583
+
584
+ output_states += (hidden_states,)
585
+
586
+ if self.downsamplers is not None:
587
+ for downsampler in self.downsamplers:
588
+ hidden_states = downsampler(hidden_states)
589
+
590
+ output_states += (hidden_states,)
591
+
592
+ return hidden_states, output_states
593
+
594
+
595
+ class CrossAttnUpBlock3D(nn.Module):
596
+ def __init__(
597
+ self,
598
+ in_channels: int,
599
+ out_channels: int,
600
+ prev_output_channel: int,
601
+ temb_channels: int,
602
+ dropout: float = 0.0,
603
+ num_layers: int = 1,
604
+ resnet_eps: float = 1e-6,
605
+ resnet_time_scale_shift: str = "default",
606
+ resnet_act_fn: str = "swish",
607
+ resnet_groups: int = 32,
608
+ resnet_pre_norm: bool = True,
609
+ attn_num_head_channels=1,
610
+ cross_attention_dim=1280,
611
+ output_scale_factor=1.0,
612
+ add_upsample=True,
613
+ dual_cross_attention=False,
614
+ use_linear_projection=False,
615
+ only_cross_attention=False,
616
+ upcast_attention=False,
617
+ unet_use_cross_frame_attention=None,
618
+ unet_use_temporal_attention=None,
619
+ use_motion_module=None,
620
+ use_inflated_groupnorm=None,
621
+ motion_module_type=None,
622
+ motion_module_kwargs=None,
623
+ ):
624
+ super().__init__()
625
+ resnets = []
626
+ attentions = []
627
+ motion_modules = []
628
+
629
+ self.has_cross_attention = True
630
+ self.attn_num_head_channels = attn_num_head_channels
631
+
632
+ for i in range(num_layers):
633
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
634
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
635
+
636
+ resnets.append(
637
+ ResnetBlock3D(
638
+ in_channels=resnet_in_channels + res_skip_channels,
639
+ out_channels=out_channels,
640
+ temb_channels=temb_channels,
641
+ eps=resnet_eps,
642
+ groups=resnet_groups,
643
+ dropout=dropout,
644
+ time_embedding_norm=resnet_time_scale_shift,
645
+ non_linearity=resnet_act_fn,
646
+ output_scale_factor=output_scale_factor,
647
+ pre_norm=resnet_pre_norm,
648
+ use_inflated_groupnorm=use_inflated_groupnorm,
649
+ )
650
+ )
651
+ if dual_cross_attention:
652
+ raise NotImplementedError
653
+ attentions.append(
654
+ Transformer3DModel(
655
+ attn_num_head_channels,
656
+ out_channels // attn_num_head_channels,
657
+ in_channels=out_channels,
658
+ num_layers=1,
659
+ cross_attention_dim=cross_attention_dim,
660
+ norm_num_groups=resnet_groups,
661
+ use_linear_projection=use_linear_projection,
662
+ only_cross_attention=only_cross_attention,
663
+ upcast_attention=upcast_attention,
664
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
665
+ unet_use_temporal_attention=unet_use_temporal_attention,
666
+ )
667
+ )
668
+ motion_modules.append(
669
+ get_motion_module(
670
+ in_channels=out_channels,
671
+ motion_module_type=motion_module_type,
672
+ motion_module_kwargs=motion_module_kwargs,
673
+ )
674
+ if use_motion_module
675
+ else None
676
+ )
677
+
678
+ self.attentions = nn.ModuleList(attentions)
679
+ self.resnets = nn.ModuleList(resnets)
680
+ self.motion_modules = nn.ModuleList(motion_modules)
681
+
682
+ if add_upsample:
683
+ self.upsamplers = nn.ModuleList(
684
+ [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
685
+ )
686
+ else:
687
+ self.upsamplers = None
688
+
689
+ self.gradient_checkpointing = False
690
+
691
+ def forward(
692
+ self,
693
+ hidden_states,
694
+ res_hidden_states_tuple,
695
+ temb=None,
696
+ encoder_hidden_states=None,
697
+ upsample_size=None,
698
+ attention_mask=None,
699
+ ):
700
+ for i, (resnet, attn, motion_module) in enumerate(
701
+ zip(self.resnets, self.attentions, self.motion_modules)
702
+ ):
703
+ # pop res hidden states
704
+ res_hidden_states = res_hidden_states_tuple[-1]
705
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
706
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
707
+
708
+ if self.training and self.gradient_checkpointing:
709
+
710
+ def create_custom_forward(module, return_dict=None):
711
+ def custom_forward(*inputs):
712
+ if return_dict is not None:
713
+ return module(*inputs, return_dict=return_dict)
714
+ else:
715
+ return module(*inputs)
716
+
717
+ return custom_forward
718
+
719
+ hidden_states = torch.utils.checkpoint.checkpoint(
720
+ create_custom_forward(resnet), hidden_states, temb
721
+ )
722
+ hidden_states = attn(
723
+ hidden_states,
724
+ encoder_hidden_states=encoder_hidden_states,
725
+ ).sample
726
+ if motion_module is not None:
727
+ hidden_states = torch.utils.checkpoint.checkpoint(
728
+ create_custom_forward(motion_module),
729
+ hidden_states.requires_grad_(),
730
+ temb,
731
+ encoder_hidden_states,
732
+ )
733
+
734
+ else:
735
+ hidden_states = resnet(hidden_states, temb)
736
+ hidden_states = attn(
737
+ hidden_states,
738
+ encoder_hidden_states=encoder_hidden_states,
739
+ ).sample
740
+
741
+ # add motion module
742
+ hidden_states = (
743
+ motion_module(
744
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
745
+ )
746
+ if motion_module is not None
747
+ else hidden_states
748
+ )
749
+
750
+ if self.upsamplers is not None:
751
+ for upsampler in self.upsamplers:
752
+ hidden_states = upsampler(hidden_states, upsample_size)
753
+
754
+ return hidden_states
755
+
756
+
757
+ class UpBlock3D(nn.Module):
758
+ def __init__(
759
+ self,
760
+ in_channels: int,
761
+ prev_output_channel: int,
762
+ out_channels: int,
763
+ temb_channels: int,
764
+ dropout: float = 0.0,
765
+ num_layers: int = 1,
766
+ resnet_eps: float = 1e-6,
767
+ resnet_time_scale_shift: str = "default",
768
+ resnet_act_fn: str = "swish",
769
+ resnet_groups: int = 32,
770
+ resnet_pre_norm: bool = True,
771
+ output_scale_factor=1.0,
772
+ add_upsample=True,
773
+ use_inflated_groupnorm=None,
774
+ use_motion_module=None,
775
+ motion_module_type=None,
776
+ motion_module_kwargs=None,
777
+ ):
778
+ super().__init__()
779
+ resnets = []
780
+ motion_modules = []
781
+
782
+ # use_motion_module = False
783
+ for i in range(num_layers):
784
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
785
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
786
+
787
+ resnets.append(
788
+ ResnetBlock3D(
789
+ in_channels=resnet_in_channels + res_skip_channels,
790
+ out_channels=out_channels,
791
+ temb_channels=temb_channels,
792
+ eps=resnet_eps,
793
+ groups=resnet_groups,
794
+ dropout=dropout,
795
+ time_embedding_norm=resnet_time_scale_shift,
796
+ non_linearity=resnet_act_fn,
797
+ output_scale_factor=output_scale_factor,
798
+ pre_norm=resnet_pre_norm,
799
+ use_inflated_groupnorm=use_inflated_groupnorm,
800
+ )
801
+ )
802
+ motion_modules.append(
803
+ get_motion_module(
804
+ in_channels=out_channels,
805
+ motion_module_type=motion_module_type,
806
+ motion_module_kwargs=motion_module_kwargs,
807
+ )
808
+ if use_motion_module
809
+ else None
810
+ )
811
+
812
+ self.resnets = nn.ModuleList(resnets)
813
+ self.motion_modules = nn.ModuleList(motion_modules)
814
+
815
+ if add_upsample:
816
+ self.upsamplers = nn.ModuleList(
817
+ [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
818
+ )
819
+ else:
820
+ self.upsamplers = None
821
+
822
+ self.gradient_checkpointing = False
823
+
824
+ def forward(
825
+ self,
826
+ hidden_states,
827
+ res_hidden_states_tuple,
828
+ temb=None,
829
+ upsample_size=None,
830
+ encoder_hidden_states=None,
831
+ ):
832
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
833
+ # pop res hidden states
834
+ res_hidden_states = res_hidden_states_tuple[-1]
835
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
836
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
837
+
838
+ # print(f"UpBlock3D {self.gradient_checkpointing = }")
839
+ if self.training and self.gradient_checkpointing:
840
+
841
+ def create_custom_forward(module):
842
+ def custom_forward(*inputs):
843
+ return module(*inputs)
844
+
845
+ return custom_forward
846
+
847
+ hidden_states = torch.utils.checkpoint.checkpoint(
848
+ create_custom_forward(resnet), hidden_states, temb
849
+ )
850
+ if motion_module is not None:
851
+ hidden_states = torch.utils.checkpoint.checkpoint(
852
+ create_custom_forward(motion_module),
853
+ hidden_states.requires_grad_(),
854
+ temb,
855
+ encoder_hidden_states,
856
+ )
857
+ else:
858
+ hidden_states = resnet(hidden_states, temb)
859
+ hidden_states = (
860
+ motion_module(
861
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
862
+ )
863
+ if motion_module is not None
864
+ else hidden_states
865
+ )
866
+
867
+ if self.upsamplers is not None:
868
+ for upsampler in self.upsamplers:
869
+ hidden_states = upsampler(hidden_states, upsample_size)
870
+
871
+ return hidden_states
musepose/__init__.py ADDED
File without changes
musepose/dataset/dance_image.py ADDED
@@ -0,0 +1,130 @@
1
+ import json
2
+ import random
3
+
4
+ import torch
5
+ import torchvision.transforms as transforms
6
+ from decord import VideoReader
7
+ from PIL import Image
8
+ from torch.utils.data import Dataset
9
+ from transformers import CLIPImageProcessor
10
+
11
+
12
+ class HumanDanceDataset(Dataset):
13
+ def __init__(
14
+ self,
15
+ img_size,
16
+ img_scale=(1.0, 1.0),
17
+ img_ratio=(0.9, 1.0),
18
+ drop_ratio=0.1,
19
+ data_meta_paths=["./data/fahsion_meta.json"],
20
+ sample_margin=30,
21
+ ):
22
+ super().__init__()
23
+
24
+ self.img_size = img_size
25
+ self.img_scale = img_scale
26
+ self.img_ratio = img_ratio
27
+ self.sample_margin = sample_margin
28
+
29
+ # -----
30
+ # vid_meta format:
31
+ # [{'video_path': , 'kps_path': , 'other':},
32
+ # {'video_path': , 'kps_path': , 'other':}]
33
+ # -----
34
+ vid_meta = []
35
+ for data_meta_path in data_meta_paths:
36
+ vid_meta.extend(json.load(open(data_meta_path, "r")))
37
+ self.vid_meta = vid_meta
38
+
39
+ self.clip_image_processor = CLIPImageProcessor()
40
+
41
+ self.transform = transforms.Compose(
42
+ [
43
+ # transforms.RandomResizedCrop(
44
+ # self.img_size,
45
+ # scale=self.img_scale,
46
+ # ratio=self.img_ratio,
47
+ # interpolation=transforms.InterpolationMode.BILINEAR,
48
+ # ),
49
+ transforms.Resize(
50
+ self.img_size,
51
+ ),
52
+ transforms.ToTensor(),
53
+ transforms.Normalize([0.5], [0.5]),
54
+ ]
55
+ )
56
+
57
+ self.cond_transform = transforms.Compose(
58
+ [
59
+ # transforms.RandomResizedCrop(
60
+ # self.img_size,
61
+ # scale=self.img_scale,
62
+ # ratio=self.img_ratio,
63
+ # interpolation=transforms.InterpolationMode.BILINEAR,
64
+ # ),
65
+ transforms.Resize(
66
+ self.img_size,
67
+ ),
68
+ transforms.ToTensor(),
69
+ ]
70
+ )
71
+
72
+ self.drop_ratio = drop_ratio
73
+
74
+ def augmentation(self, image, transform, state=None):
75
+ if state is not None:
76
+ torch.set_rng_state(state)
77
+ return transform(image)
78
+
79
+ def __getitem__(self, index):
80
+ video_meta = self.vid_meta[index]
81
+ video_path = video_meta["video_path"]
82
+ kps_path = video_meta["kps_path"]
83
+
84
+ video_reader = VideoReader(video_path)
85
+ kps_reader = VideoReader(kps_path)
86
+
87
+ assert len(video_reader) == len(
88
+ kps_reader
89
+ ), f"{len(video_reader) = } != {len(kps_reader) = } in {video_path}"
90
+
91
+ video_length = len(video_reader)
92
+
93
+ margin = min(self.sample_margin, video_length)
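+ # keep the reference and target frames at least `margin` frames apart when the video is long
+ # enough, falling back to a fully random target frame otherwise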
94
+
95
+ ref_img_idx = random.randint(0, video_length - 1)
96
+ if ref_img_idx + margin < video_length:
97
+ tgt_img_idx = random.randint(ref_img_idx + margin, video_length - 1)
98
+ elif ref_img_idx - margin > 0:
99
+ tgt_img_idx = random.randint(0, ref_img_idx - margin)
100
+ else:
101
+ tgt_img_idx = random.randint(0, video_length - 1)
102
+
103
+ ref_img = video_reader[ref_img_idx]
104
+ ref_img_pil = Image.fromarray(ref_img.asnumpy())
105
+ tgt_img = video_reader[tgt_img_idx]
106
+ tgt_img_pil = Image.fromarray(tgt_img.asnumpy())
107
+
108
+ tgt_pose = kps_reader[tgt_img_idx]
109
+ tgt_pose_pil = Image.fromarray(tgt_pose.asnumpy())
110
+
111
+ state = torch.get_rng_state()
112
+ tgt_img = self.augmentation(tgt_img_pil, self.transform, state)
113
+ tgt_pose_img = self.augmentation(tgt_pose_pil, self.cond_transform, state)
114
+ ref_img_vae = self.augmentation(ref_img_pil, self.transform, state)
115
+ clip_image = self.clip_image_processor(
116
+ images=ref_img_pil, return_tensors="pt"
117
+ ).pixel_values[0]
118
+
119
+ sample = dict(
120
+ video_dir=video_path,
121
+ img=tgt_img,
122
+ tgt_pose=tgt_pose_img,
123
+ ref_img=ref_img_vae,
124
+ clip_images=clip_image,
125
+ )
126
+
127
+ return sample
128
+
129
+ def __len__(self):
130
+ return len(self.vid_meta)
musepose/dataset/dance_video.py ADDED
@@ -0,0 +1,150 @@
1
+ import json
2
+ import random
3
+ from typing import List
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import torch
8
+ import torchvision.transforms as transforms
9
+ from decord import VideoReader
10
+ from PIL import Image
11
+ from torch.utils.data import Dataset
12
+ from transformers import CLIPImageProcessor
13
+
14
+
15
+ class HumanDanceVideoDataset(Dataset):
16
+ def __init__(
17
+ self,
18
+ sample_rate,
19
+ n_sample_frames,
20
+ width,
21
+ height,
22
+ img_scale=(1.0, 1.0),
23
+ img_ratio=(0.9, 1.0),
24
+ drop_ratio=0.1,
25
+ data_meta_paths=["./data/fashion_meta.json"],
26
+ ):
27
+ super().__init__()
28
+ self.sample_rate = sample_rate
29
+ self.n_sample_frames = n_sample_frames
30
+ self.width = width
31
+ self.height = height
32
+ self.img_scale = img_scale
33
+ self.img_ratio = img_ratio
34
+
35
+ vid_meta = []
36
+ for data_meta_path in data_meta_paths:
37
+ vid_meta.extend(json.load(open(data_meta_path, "r")))
38
+ self.vid_meta = vid_meta
39
+
40
+ self.clip_image_processor = CLIPImageProcessor()
41
+
42
+ self.pixel_transform = transforms.Compose(
43
+ [
44
+ # transforms.RandomResizedCrop(
45
+ # (height, width),
46
+ # scale=self.img_scale,
47
+ # ratio=self.img_ratio,
48
+ # interpolation=transforms.InterpolationMode.BILINEAR,
49
+ # ),
50
+ transforms.Resize(
51
+ (height, width),
52
+ ),
53
+ transforms.ToTensor(),
54
+ transforms.Normalize([0.5], [0.5]),
55
+ ]
56
+ )
57
+
58
+ self.cond_transform = transforms.Compose(
59
+ [
60
+ # transforms.RandomResizedCrop(
61
+ # (height, width),
62
+ # scale=self.img_scale,
63
+ # ratio=self.img_ratio,
64
+ # interpolation=transforms.InterpolationMode.BILINEAR,
65
+ # ),
66
+ transforms.Resize(
67
+ (height, width),
68
+ ),
69
+ transforms.ToTensor(),
70
+ ]
71
+ )
72
+
73
+ self.drop_ratio = drop_ratio
74
+
75
+ def augmentation(self, images, transform, state=None):
76
+ if state is not None:
77
+ torch.set_rng_state(state)
78
+ if isinstance(images, List):
79
+ transformed_images = [transform(img) for img in images]
80
+ ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w)
81
+ else:
82
+ ret_tensor = transform(images) # (c, h, w)
83
+ return ret_tensor
84
+
85
+ def __getitem__(self, index):
86
+ video_meta = self.vid_meta[index]
87
+ video_path = video_meta["video_path"]
88
+ kps_path = video_meta["kps_path"]
89
+
90
+ video_reader = VideoReader(video_path)
91
+ kps_reader = VideoReader(kps_path)
92
+
93
+ assert len(video_reader) == len(
94
+ kps_reader
95
+ ), f"{len(video_reader) = } != {len(kps_reader) = } in {video_path}"
96
+
97
+ video_length = len(video_reader)
98
+ video_fps = video_reader.get_avg_fps()
99
+ # print("fps", video_fps)
100
+ if video_fps > 30:  # 30-60 fps footage: double the stride so sampled clips span a similar duration
101
+ sample_rate = self.sample_rate * 2
102
+ else:
103
+ sample_rate = self.sample_rate
104
+
105
+
106
+ clip_length = min(
107
+ video_length, (self.n_sample_frames - 1) * sample_rate + 1
108
+ )
109
+ start_idx = random.randint(0, video_length - clip_length)
110
+ batch_index = np.linspace(
111
+ start_idx, start_idx + clip_length - 1, self.n_sample_frames, dtype=int
112
+ ).tolist()
113
+
114
+ # read frames and kps
115
+ vid_pil_image_list = []
116
+ pose_pil_image_list = []
117
+ for index in batch_index:
118
+ img = video_reader[index]
119
+ vid_pil_image_list.append(Image.fromarray(img.asnumpy()))
120
+ img = kps_reader[index]
121
+ pose_pil_image_list.append(Image.fromarray(img.asnumpy()))
122
+
123
+ ref_img_idx = random.randint(0, video_length - 1)
124
+ ref_img = Image.fromarray(video_reader[ref_img_idx].asnumpy())
125
+
126
+ # transform
127
+ state = torch.get_rng_state()
128
+ pixel_values_vid = self.augmentation(
129
+ vid_pil_image_list, self.pixel_transform, state
130
+ )
131
+ pixel_values_pose = self.augmentation(
132
+ pose_pil_image_list, self.cond_transform, state
133
+ )
134
+ pixel_values_ref_img = self.augmentation(ref_img, self.pixel_transform, state)
135
+ clip_ref_img = self.clip_image_processor(
136
+ images=ref_img, return_tensors="pt"
137
+ ).pixel_values[0]
138
+
139
+ sample = dict(
140
+ video_dir=video_path,
141
+ pixel_values_vid=pixel_values_vid,
142
+ pixel_values_pose=pixel_values_pose,
143
+ pixel_values_ref_img=pixel_values_ref_img,
144
+ clip_ref_img=clip_ref_img,
145
+ )
146
+
147
+ return sample
148
+
149
+ def __len__(self):
150
+ return len(self.vid_meta)
musepose/models/attention.py ADDED
@@ -0,0 +1,443 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2
+
3
+ from typing import Any, Dict, Optional
4
+
5
+ import torch
6
+ from diffusers.models.attention import AdaLayerNorm, AdaLayerNormZero, Attention, FeedForward, GatedSelfAttentionDense
7
+ from diffusers.models.embeddings import SinusoidalPositionalEmbedding
8
+ from einops import rearrange
9
+ from torch import nn
10
+
11
+
12
+ class BasicTransformerBlock(nn.Module):
13
+ r"""
14
+ A basic Transformer block.
15
+
16
+ Parameters:
17
+ dim (`int`): The number of channels in the input and output.
18
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
19
+ attention_head_dim (`int`): The number of channels in each head.
20
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
21
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
22
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
23
+ num_embeds_ada_norm (:
24
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
25
+ attention_bias (:
26
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
27
+ only_cross_attention (`bool`, *optional*):
28
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
29
+ double_self_attention (`bool`, *optional*):
30
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
31
+ upcast_attention (`bool`, *optional*):
32
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
33
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
34
+ Whether to use learnable elementwise affine parameters for normalization.
35
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
36
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
37
+ final_dropout (`bool` *optional*, defaults to False):
38
+ Whether to apply a final dropout after the last feed-forward layer.
39
+ attention_type (`str`, *optional*, defaults to `"default"`):
40
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
41
+ positional_embeddings (`str`, *optional*, defaults to `None`):
42
+ The type of positional embeddings to apply to.
43
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
44
+ The maximum number of positional embeddings to apply.
45
+ """
46
+
47
+ def __init__(
48
+ self,
49
+ dim: int,
50
+ num_attention_heads: int,
51
+ attention_head_dim: int,
52
+ dropout=0.0,
53
+ cross_attention_dim: Optional[int] = None,
54
+ activation_fn: str = "geglu",
55
+ num_embeds_ada_norm: Optional[int] = None,
56
+ attention_bias: bool = False,
57
+ only_cross_attention: bool = False,
58
+ double_self_attention: bool = False,
59
+ upcast_attention: bool = False,
60
+ norm_elementwise_affine: bool = True,
61
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
62
+ norm_eps: float = 1e-5,
63
+ final_dropout: bool = False,
64
+ attention_type: str = "default",
65
+ positional_embeddings: Optional[str] = None,
66
+ num_positional_embeddings: Optional[int] = None,
67
+ ):
68
+ super().__init__()
69
+ self.only_cross_attention = only_cross_attention
70
+
71
+ self.use_ada_layer_norm_zero = (
72
+ num_embeds_ada_norm is not None
73
+ ) and norm_type == "ada_norm_zero"
74
+ self.use_ada_layer_norm = (
75
+ num_embeds_ada_norm is not None
76
+ ) and norm_type == "ada_norm"
77
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
78
+ self.use_layer_norm = norm_type == "layer_norm"
79
+
80
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
81
+ raise ValueError(
82
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
83
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
84
+ )
85
+
86
+ if positional_embeddings and (num_positional_embeddings is None):
87
+ raise ValueError(
88
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
89
+ )
90
+
91
+ if positional_embeddings == "sinusoidal":
92
+ self.pos_embed = SinusoidalPositionalEmbedding(
93
+ dim, max_seq_length=num_positional_embeddings
94
+ )
95
+ else:
96
+ self.pos_embed = None
97
+
98
+ # Define 3 blocks. Each block has its own normalization layer.
99
+ # 1. Self-Attn
100
+ if self.use_ada_layer_norm:
101
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
102
+ elif self.use_ada_layer_norm_zero:
103
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
104
+ else:
105
+ self.norm1 = nn.LayerNorm(
106
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
107
+ )
108
+
109
+ self.attn1 = Attention(
110
+ query_dim=dim,
111
+ heads=num_attention_heads,
112
+ dim_head=attention_head_dim,
113
+ dropout=dropout,
114
+ bias=attention_bias,
115
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
116
+ upcast_attention=upcast_attention,
117
+ )
118
+
119
+ # 2. Cross-Attn
120
+ if cross_attention_dim is not None or double_self_attention:
121
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
122
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
123
+ # the second cross attention block.
124
+ self.norm2 = (
125
+ AdaLayerNorm(dim, num_embeds_ada_norm)
126
+ if self.use_ada_layer_norm
127
+ else nn.LayerNorm(
128
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
129
+ )
130
+ )
131
+ self.attn2 = Attention(
132
+ query_dim=dim,
133
+ cross_attention_dim=cross_attention_dim
134
+ if not double_self_attention
135
+ else None,
136
+ heads=num_attention_heads,
137
+ dim_head=attention_head_dim,
138
+ dropout=dropout,
139
+ bias=attention_bias,
140
+ upcast_attention=upcast_attention,
141
+ ) # is self-attn if encoder_hidden_states is none
142
+ else:
143
+ self.norm2 = None
144
+ self.attn2 = None
145
+
146
+ # 3. Feed-forward
147
+ if not self.use_ada_layer_norm_single:
148
+ self.norm3 = nn.LayerNorm(
149
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
150
+ )
151
+
152
+ self.ff = FeedForward(
153
+ dim,
154
+ dropout=dropout,
155
+ activation_fn=activation_fn,
156
+ final_dropout=final_dropout,
157
+ )
158
+
159
+ # 4. Fuser
160
+ if attention_type == "gated" or attention_type == "gated-text-image":
161
+ self.fuser = GatedSelfAttentionDense(
162
+ dim, cross_attention_dim, num_attention_heads, attention_head_dim
163
+ )
164
+
165
+ # 5. Scale-shift for PixArt-Alpha.
166
+ if self.use_ada_layer_norm_single:
167
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
168
+
169
+ # let chunk size default to None
170
+ self._chunk_size = None
171
+ self._chunk_dim = 0
172
+
173
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
174
+ # Sets chunk feed-forward
175
+ self._chunk_size = chunk_size
176
+ self._chunk_dim = dim
177
+
178
+ def forward(
179
+ self,
180
+ hidden_states: torch.FloatTensor,
181
+ attention_mask: Optional[torch.FloatTensor] = None,
182
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
183
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
184
+ timestep: Optional[torch.LongTensor] = None,
185
+ cross_attention_kwargs: Dict[str, Any] = None,
186
+ class_labels: Optional[torch.LongTensor] = None,
187
+ ) -> torch.FloatTensor:
188
+ # Notice that normalization is always applied before the real computation in the following blocks.
189
+ # 0. Self-Attention
190
+ batch_size = hidden_states.shape[0]
191
+
192
+ if self.use_ada_layer_norm:
193
+ norm_hidden_states = self.norm1(hidden_states, timestep)
194
+ elif self.use_ada_layer_norm_zero:
195
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
196
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
197
+ )
198
+ elif self.use_layer_norm:
199
+ norm_hidden_states = self.norm1(hidden_states)
200
+ elif self.use_ada_layer_norm_single:
201
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
202
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
203
+ ).chunk(6, dim=1)
204
+ norm_hidden_states = self.norm1(hidden_states)
205
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
206
+ norm_hidden_states = norm_hidden_states.squeeze(1)
207
+ else:
208
+ raise ValueError("Incorrect norm used")
209
+
210
+ if self.pos_embed is not None:
211
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
212
+
213
+ # 1. Retrieve lora scale.
214
+ lora_scale = (
215
+ cross_attention_kwargs.get("scale", 1.0)
216
+ if cross_attention_kwargs is not None
217
+ else 1.0
218
+ )
219
+
220
+ # 2. Prepare GLIGEN inputs
221
+ cross_attention_kwargs = (
222
+ cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
223
+ )
224
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
225
+
226
+ attn_output = self.attn1(
227
+ norm_hidden_states,
228
+ encoder_hidden_states=encoder_hidden_states
229
+ if self.only_cross_attention
230
+ else None,
231
+ attention_mask=attention_mask,
232
+ **cross_attention_kwargs,
233
+ )
234
+ if self.use_ada_layer_norm_zero:
235
+ attn_output = gate_msa.unsqueeze(1) * attn_output
236
+ elif self.use_ada_layer_norm_single:
237
+ attn_output = gate_msa * attn_output
238
+
239
+ hidden_states = attn_output + hidden_states
240
+ if hidden_states.ndim == 4:
241
+ hidden_states = hidden_states.squeeze(1)
242
+
243
+ # 2.5 GLIGEN Control
244
+ if gligen_kwargs is not None:
245
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
246
+
247
+ # 3. Cross-Attention
248
+ if self.attn2 is not None:
249
+ if self.use_ada_layer_norm:
250
+ norm_hidden_states = self.norm2(hidden_states, timestep)
251
+ elif self.use_ada_layer_norm_zero or self.use_layer_norm:
252
+ norm_hidden_states = self.norm2(hidden_states)
253
+ elif self.use_ada_layer_norm_single:
254
+ # For PixArt norm2 isn't applied here:
255
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
256
+ norm_hidden_states = hidden_states
257
+ else:
258
+ raise ValueError("Incorrect norm")
259
+
260
+ if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
261
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
262
+
263
+ attn_output = self.attn2(
264
+ norm_hidden_states,
265
+ encoder_hidden_states=encoder_hidden_states,
266
+ attention_mask=encoder_attention_mask,
267
+ **cross_attention_kwargs,
268
+ )
269
+ hidden_states = attn_output + hidden_states
270
+
271
+ # 4. Feed-forward
272
+ if not self.use_ada_layer_norm_single:
273
+ norm_hidden_states = self.norm3(hidden_states)
274
+
275
+ if self.use_ada_layer_norm_zero:
276
+ norm_hidden_states = (
277
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
278
+ )
279
+
280
+ if self.use_ada_layer_norm_single:
281
+ norm_hidden_states = self.norm2(hidden_states)
282
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
283
+
284
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
285
+
286
+ if self.use_ada_layer_norm_zero:
287
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
288
+ elif self.use_ada_layer_norm_single:
289
+ ff_output = gate_mlp * ff_output
290
+
291
+ hidden_states = ff_output + hidden_states
292
+ if hidden_states.ndim == 4:
293
+ hidden_states = hidden_states.squeeze(1)
294
+
295
+ return hidden_states
296
+
297
+
298
+ class TemporalBasicTransformerBlock(nn.Module):
299
+ def __init__(
300
+ self,
301
+ dim: int,
302
+ num_attention_heads: int,
303
+ attention_head_dim: int,
304
+ dropout=0.0,
305
+ cross_attention_dim: Optional[int] = None,
306
+ activation_fn: str = "geglu",
307
+ num_embeds_ada_norm: Optional[int] = None,
308
+ attention_bias: bool = False,
309
+ only_cross_attention: bool = False,
310
+ upcast_attention: bool = False,
311
+ unet_use_cross_frame_attention=None,
312
+ unet_use_temporal_attention=None,
313
+ ):
314
+ super().__init__()
315
+ self.only_cross_attention = only_cross_attention
316
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
317
+ self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
318
+ self.unet_use_temporal_attention = unet_use_temporal_attention
319
+
320
+ # SC-Attn
321
+ self.attn1 = Attention(
322
+ query_dim=dim,
323
+ heads=num_attention_heads,
324
+ dim_head=attention_head_dim,
325
+ dropout=dropout,
326
+ bias=attention_bias,
327
+ upcast_attention=upcast_attention,
328
+ )
329
+ self.norm1 = (
330
+ AdaLayerNorm(dim, num_embeds_ada_norm)
331
+ if self.use_ada_layer_norm
332
+ else nn.LayerNorm(dim)
333
+ )
334
+
335
+ # Cross-Attn
336
+ if cross_attention_dim is not None:
337
+ self.attn2 = Attention(
338
+ query_dim=dim,
339
+ cross_attention_dim=cross_attention_dim,
340
+ heads=num_attention_heads,
341
+ dim_head=attention_head_dim,
342
+ dropout=dropout,
343
+ bias=attention_bias,
344
+ upcast_attention=upcast_attention,
345
+ )
346
+ else:
347
+ self.attn2 = None
348
+
349
+ if cross_attention_dim is not None:
350
+ self.norm2 = (
351
+ AdaLayerNorm(dim, num_embeds_ada_norm)
352
+ if self.use_ada_layer_norm
353
+ else nn.LayerNorm(dim)
354
+ )
355
+ else:
356
+ self.norm2 = None
357
+
358
+ # Feed-forward
359
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
360
+ self.norm3 = nn.LayerNorm(dim)
361
+ self.use_ada_layer_norm_zero = False
362
+
363
+ # Temp-Attn
364
+ assert unet_use_temporal_attention is not None
365
+ if unet_use_temporal_attention:
366
+ self.attn_temp = Attention(
367
+ query_dim=dim,
368
+ heads=num_attention_heads,
369
+ dim_head=attention_head_dim,
370
+ dropout=dropout,
371
+ bias=attention_bias,
372
+ upcast_attention=upcast_attention,
373
+ )
374
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
375
+ self.norm_temp = (
376
+ AdaLayerNorm(dim, num_embeds_ada_norm)
377
+ if self.use_ada_layer_norm
378
+ else nn.LayerNorm(dim)
379
+ )
380
+
381
+ def forward(
382
+ self,
383
+ hidden_states,
384
+ encoder_hidden_states=None,
385
+ timestep=None,
386
+ attention_mask=None,
387
+ video_length=None,
388
+ ):
389
+ norm_hidden_states = (
390
+ self.norm1(hidden_states, timestep)
391
+ if self.use_ada_layer_norm
392
+ else self.norm1(hidden_states)
393
+ )
394
+
395
+ if self.unet_use_cross_frame_attention:
396
+ hidden_states = (
397
+ self.attn1(
398
+ norm_hidden_states,
399
+ attention_mask=attention_mask,
400
+ video_length=video_length,
401
+ )
402
+ + hidden_states
403
+ )
404
+ else:
405
+ hidden_states = (
406
+ self.attn1(norm_hidden_states, attention_mask=attention_mask)
407
+ + hidden_states
408
+ )
409
+
410
+ if self.attn2 is not None:
411
+ # Cross-Attention
412
+ norm_hidden_states = (
413
+ self.norm2(hidden_states, timestep)
414
+ if self.use_ada_layer_norm
415
+ else self.norm2(hidden_states)
416
+ )
417
+ hidden_states = (
418
+ self.attn2(
419
+ norm_hidden_states,
420
+ encoder_hidden_states=encoder_hidden_states,
421
+ attention_mask=attention_mask,
422
+ )
423
+ + hidden_states
424
+ )
425
+
426
+ # Feed-forward
427
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
428
+
429
+ # Temporal-Attention
430
+ if self.unet_use_temporal_attention:
431
+ d = hidden_states.shape[1]
432
+ hidden_states = rearrange(
433
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
434
+ )
435
+ norm_hidden_states = (
436
+ self.norm_temp(hidden_states, timestep)
437
+ if self.use_ada_layer_norm
438
+ else self.norm_temp(hidden_states)
439
+ )
440
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
441
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
442
+
443
+ return hidden_states
musepose/models/motion_module.py ADDED
@@ -0,0 +1,388 @@
1
+ # Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Callable, Optional
5
+
6
+ import torch
7
+ from diffusers.models.attention import FeedForward
8
+ from diffusers.models.attention_processor import Attention, AttnProcessor
9
+ from diffusers.utils import BaseOutput
10
+ from diffusers.utils.import_utils import is_xformers_available
11
+ from einops import rearrange, repeat
12
+ from torch import nn
13
+
14
+
15
+ def zero_module(module):
16
+ # Zero out the parameters of a module and return it.
17
+ for p in module.parameters():
18
+ p.detach().zero_()
19
+ return module
20
+
21
+
22
+ @dataclass
23
+ class TemporalTransformer3DModelOutput(BaseOutput):
24
+ sample: torch.FloatTensor
25
+
26
+
27
+ if is_xformers_available():
28
+ import xformers
29
+ import xformers.ops
30
+ else:
31
+ xformers = None
32
+
33
+
34
+ def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
35
+ if motion_module_type == "Vanilla":
36
+ return VanillaTemporalModule(
37
+ in_channels=in_channels,
38
+ **motion_module_kwargs,
39
+ )
40
+ else:
41
+ raise ValueError(f"unknown motion_module_type: {motion_module_type}")
42
+
43
+
44
+ class VanillaTemporalModule(nn.Module):
45
+ def __init__(
46
+ self,
47
+ in_channels,
48
+ num_attention_heads=8,
49
+ num_transformer_block=2,
50
+ attention_block_types=("Temporal_Self", "Temporal_Self"),
51
+ cross_frame_attention_mode=None,
52
+ temporal_position_encoding=False,
53
+ temporal_position_encoding_max_len=24,
54
+ temporal_attention_dim_div=1,
55
+ zero_initialize=True,
56
+ ):
57
+ super().__init__()
58
+
59
+ self.temporal_transformer = TemporalTransformer3DModel(
60
+ in_channels=in_channels,
61
+ num_attention_heads=num_attention_heads,
62
+ attention_head_dim=in_channels
63
+ // num_attention_heads
64
+ // temporal_attention_dim_div,
65
+ num_layers=num_transformer_block,
66
+ attention_block_types=attention_block_types,
67
+ cross_frame_attention_mode=cross_frame_attention_mode,
68
+ temporal_position_encoding=temporal_position_encoding,
69
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
70
+ )
71
+
72
+ if zero_initialize:
73
+ self.temporal_transformer.proj_out = zero_module(
74
+ self.temporal_transformer.proj_out
75
+ )
76
+
77
+ def forward(
78
+ self,
79
+ input_tensor,
80
+ temb,
81
+ encoder_hidden_states,
82
+ attention_mask=None,
83
+ anchor_frame_idx=None,
84
+ ):
85
+ hidden_states = input_tensor
86
+ hidden_states = self.temporal_transformer(
87
+ hidden_states, encoder_hidden_states, attention_mask
88
+ )
89
+
90
+ output = hidden_states
91
+ return output
92
+
93
+
94
+ class TemporalTransformer3DModel(nn.Module):
95
+ def __init__(
96
+ self,
97
+ in_channels,
98
+ num_attention_heads,
99
+ attention_head_dim,
100
+ num_layers,
101
+ attention_block_types=(
102
+ "Temporal_Self",
103
+ "Temporal_Self",
104
+ ),
105
+ dropout=0.0,
106
+ norm_num_groups=32,
107
+ cross_attention_dim=768,
108
+ activation_fn="geglu",
109
+ attention_bias=False,
110
+ upcast_attention=False,
111
+ cross_frame_attention_mode=None,
112
+ temporal_position_encoding=False,
113
+ temporal_position_encoding_max_len=24,
114
+ ):
115
+ super().__init__()
116
+
117
+ inner_dim = num_attention_heads * attention_head_dim
118
+
119
+ self.norm = torch.nn.GroupNorm(
120
+ num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
121
+ )
122
+ self.proj_in = nn.Linear(in_channels, inner_dim)
123
+
124
+ self.transformer_blocks = nn.ModuleList(
125
+ [
126
+ TemporalTransformerBlock(
127
+ dim=inner_dim,
128
+ num_attention_heads=num_attention_heads,
129
+ attention_head_dim=attention_head_dim,
130
+ attention_block_types=attention_block_types,
131
+ dropout=dropout,
132
+ norm_num_groups=norm_num_groups,
133
+ cross_attention_dim=cross_attention_dim,
134
+ activation_fn=activation_fn,
135
+ attention_bias=attention_bias,
136
+ upcast_attention=upcast_attention,
137
+ cross_frame_attention_mode=cross_frame_attention_mode,
138
+ temporal_position_encoding=temporal_position_encoding,
139
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
140
+ )
141
+ for d in range(num_layers)
142
+ ]
143
+ )
144
+ self.proj_out = nn.Linear(inner_dim, in_channels)
145
+
146
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
147
+ assert (
148
+ hidden_states.dim() == 5
149
+ ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
150
+ video_length = hidden_states.shape[2]
151
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
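+ # frames are folded into the batch dimension here; VersatileAttention later regroups the
+ # tokens along the frame axis so attention runs across time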
152
+
153
+ batch, channel, height, weight = hidden_states.shape
154
+ residual = hidden_states
155
+
156
+ hidden_states = self.norm(hidden_states)
157
+ inner_dim = hidden_states.shape[1]
158
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
159
+ batch, height * weight, inner_dim
160
+ )
161
+ hidden_states = self.proj_in(hidden_states)
162
+
163
+ # Transformer Blocks
164
+ for block in self.transformer_blocks:
165
+ hidden_states = block(
166
+ hidden_states,
167
+ encoder_hidden_states=encoder_hidden_states,
168
+ video_length=video_length,
169
+ )
170
+
171
+ # output
172
+ hidden_states = self.proj_out(hidden_states)
173
+ hidden_states = (
174
+ hidden_states.reshape(batch, height, weight, inner_dim)
175
+ .permute(0, 3, 1, 2)
176
+ .contiguous()
177
+ )
178
+
179
+ output = hidden_states + residual
180
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
181
+
182
+ return output
183
+
184
+
185
+ class TemporalTransformerBlock(nn.Module):
186
+ def __init__(
187
+ self,
188
+ dim,
189
+ num_attention_heads,
190
+ attention_head_dim,
191
+ attention_block_types=(
192
+ "Temporal_Self",
193
+ "Temporal_Self",
194
+ ),
195
+ dropout=0.0,
196
+ norm_num_groups=32,
197
+ cross_attention_dim=768,
198
+ activation_fn="geglu",
199
+ attention_bias=False,
200
+ upcast_attention=False,
201
+ cross_frame_attention_mode=None,
202
+ temporal_position_encoding=False,
203
+ temporal_position_encoding_max_len=24,
204
+ ):
205
+ super().__init__()
206
+
207
+ attention_blocks = []
208
+ norms = []
209
+
210
+ for block_name in attention_block_types:
211
+ attention_blocks.append(
212
+ VersatileAttention(
213
+ attention_mode=block_name.split("_")[0],
214
+ cross_attention_dim=cross_attention_dim
215
+ if block_name.endswith("_Cross")
216
+ else None,
217
+ query_dim=dim,
218
+ heads=num_attention_heads,
219
+ dim_head=attention_head_dim,
220
+ dropout=dropout,
221
+ bias=attention_bias,
222
+ upcast_attention=upcast_attention,
223
+ cross_frame_attention_mode=cross_frame_attention_mode,
224
+ temporal_position_encoding=temporal_position_encoding,
225
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
226
+ )
227
+ )
228
+ norms.append(nn.LayerNorm(dim))
229
+
230
+ self.attention_blocks = nn.ModuleList(attention_blocks)
231
+ self.norms = nn.ModuleList(norms)
232
+
233
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
234
+ self.ff_norm = nn.LayerNorm(dim)
235
+
236
+ def forward(
237
+ self,
238
+ hidden_states,
239
+ encoder_hidden_states=None,
240
+ attention_mask=None,
241
+ video_length=None,
242
+ ):
243
+ for attention_block, norm in zip(self.attention_blocks, self.norms):
244
+ norm_hidden_states = norm(hidden_states)
245
+ hidden_states = (
246
+ attention_block(
247
+ norm_hidden_states,
248
+ encoder_hidden_states=encoder_hidden_states
249
+ if attention_block.is_cross_attention
250
+ else None,
251
+ video_length=video_length,
252
+ )
253
+ + hidden_states
254
+ )
255
+
256
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
257
+
258
+ output = hidden_states
259
+ return output
260
+
261
+
262
+ class PositionalEncoding(nn.Module):
263
+ def __init__(self, d_model, dropout=0.0, max_len=24):
264
+ super().__init__()
265
+ self.dropout = nn.Dropout(p=dropout)
266
+ position = torch.arange(max_len).unsqueeze(1)
267
+ div_term = torch.exp(
268
+ torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
269
+ )
270
+ pe = torch.zeros(1, max_len, d_model)
271
+ pe[0, :, 0::2] = torch.sin(position * div_term)
272
+ pe[0, :, 1::2] = torch.cos(position * div_term)
273
+ self.register_buffer("pe", pe)
274
+
275
+ def forward(self, x):
276
+ x = x + self.pe[:, : x.size(1)]
277
+ return self.dropout(x)
278
+
279
+
280
+ class VersatileAttention(Attention):
281
+ def __init__(
282
+ self,
283
+ attention_mode=None,
284
+ cross_frame_attention_mode=None,
285
+ temporal_position_encoding=False,
286
+ temporal_position_encoding_max_len=24,
287
+ *args,
288
+ **kwargs,
289
+ ):
290
+ super().__init__(*args, **kwargs)
291
+ assert attention_mode == "Temporal"
292
+
293
+ self.attention_mode = attention_mode
294
+ self.is_cross_attention = kwargs["cross_attention_dim"] is not None
295
+
296
+ self.pos_encoder = (
297
+ PositionalEncoding(
298
+ kwargs["query_dim"],
299
+ dropout=0.0,
300
+ max_len=temporal_position_encoding_max_len,
301
+ )
302
+ if (temporal_position_encoding and attention_mode == "Temporal")
303
+ else None
304
+ )
305
+
306
+ def extra_repr(self):
307
+ return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
308
+
309
+ def set_use_memory_efficient_attention_xformers(
310
+ self,
311
+ use_memory_efficient_attention_xformers: bool,
312
+ attention_op: Optional[Callable] = None,
313
+ ):
314
+ if use_memory_efficient_attention_xformers:
315
+ if not is_xformers_available():
316
+ raise ModuleNotFoundError(
317
+ (
318
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
319
+ " xformers"
320
+ ),
321
+ name="xformers",
322
+ )
323
+ elif not torch.cuda.is_available():
324
+ raise ValueError(
325
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
326
+ " only available for GPU "
327
+ )
328
+ else:
329
+ try:
330
+ # Make sure we can run the memory efficient attention
331
+ _ = xformers.ops.memory_efficient_attention(
332
+ torch.randn((1, 2, 40), device="cuda"),
333
+ torch.randn((1, 2, 40), device="cuda"),
334
+ torch.randn((1, 2, 40), device="cuda"),
335
+ )
336
+ except Exception as e:
337
+ raise e
338
+
339
+ # XFormersAttnProcessor corrupts video generation and work with Pytorch 1.13.
340
+ # Pytorch 2.0.1 AttnProcessor works the same as XFormersAttnProcessor in Pytorch 1.13.
341
+ # You don't need XFormersAttnProcessor here.
342
+ # processor = XFormersAttnProcessor(
343
+ # attention_op=attention_op,
344
+ # )
345
+ processor = AttnProcessor()
346
+ else:
347
+ processor = AttnProcessor()
348
+
349
+ self.set_processor(processor)
350
+
351
+ def forward(
352
+ self,
353
+ hidden_states,
354
+ encoder_hidden_states=None,
355
+ attention_mask=None,
356
+ video_length=None,
357
+ **cross_attention_kwargs,
358
+ ):
359
+ if self.attention_mode == "Temporal":
360
+ d = hidden_states.shape[1] # d means HxW
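+ # fold spatial positions into the batch so the attention below runs along the frame axis
+ # independently at each spatial location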
361
+ hidden_states = rearrange(
362
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
363
+ )
364
+
365
+ if self.pos_encoder is not None:
366
+ hidden_states = self.pos_encoder(hidden_states)
367
+
368
+ encoder_hidden_states = (
369
+ repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
370
+ if encoder_hidden_states is not None
371
+ else encoder_hidden_states
372
+ )
373
+
374
+ else:
375
+ raise NotImplementedError
376
+
377
+ hidden_states = self.processor(
378
+ self,
379
+ hidden_states,
380
+ encoder_hidden_states=encoder_hidden_states,
381
+ attention_mask=attention_mask,
382
+ **cross_attention_kwargs,
383
+ )
384
+
385
+ if self.attention_mode == "Temporal":
386
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
387
+
388
+ return hidden_states
musepose/models/mutual_self_attention.py ADDED
@@ -0,0 +1,363 @@
1
+ # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/models/mutual_self_attention.py
2
+ from typing import Any, Dict, Optional
3
+
4
+ import torch
5
+ from einops import rearrange
6
+
7
+ from musepose.models.attention import TemporalBasicTransformerBlock
8
+
9
+ from .attention import BasicTransformerBlock
10
+
11
+
12
+ def torch_dfs(model: torch.nn.Module):
13
+ result = [model]
14
+ for child in model.children():
15
+ result += torch_dfs(child)
16
+ return result
17
+
18
+
19
+ class ReferenceAttentionControl:
20
+ def __init__(
21
+ self,
22
+ unet,
23
+ mode="write",
24
+ do_classifier_free_guidance=False,
25
+ attention_auto_machine_weight=float("inf"),
26
+ gn_auto_machine_weight=1.0,
27
+ style_fidelity=1.0,
28
+ reference_attn=True,
29
+ reference_adain=False,
30
+ fusion_blocks="midup",
31
+ batch_size=1,
32
+ ) -> None:
33
+ # 10. Modify self attention and group norm
34
+ self.unet = unet
35
+ assert mode in ["read", "write"]
36
+ assert fusion_blocks in ["midup", "full"]
37
+ self.reference_attn = reference_attn
38
+ self.reference_adain = reference_adain
39
+ self.fusion_blocks = fusion_blocks
40
+ self.register_reference_hooks(
41
+ mode,
42
+ do_classifier_free_guidance,
43
+ attention_auto_machine_weight,
44
+ gn_auto_machine_weight,
45
+ style_fidelity,
46
+ reference_attn,
47
+ reference_adain,
48
+ fusion_blocks,
49
+ batch_size=batch_size,
50
+ )
51
+
52
+ def register_reference_hooks(
53
+ self,
54
+ mode,
55
+ do_classifier_free_guidance,
56
+ attention_auto_machine_weight,
57
+ gn_auto_machine_weight,
58
+ style_fidelity,
59
+ reference_attn,
60
+ reference_adain,
61
+ dtype=torch.float16,
62
+ batch_size=1,
63
+ num_images_per_prompt=1,
64
+ device=torch.device("cpu"),
65
+ fusion_blocks="midup",
66
+ ):
67
+ MODE = mode
68
+ do_classifier_free_guidance = do_classifier_free_guidance
69
+ attention_auto_machine_weight = attention_auto_machine_weight
70
+ gn_auto_machine_weight = gn_auto_machine_weight
71
+ style_fidelity = style_fidelity
72
+ reference_attn = reference_attn
73
+ reference_adain = reference_adain
74
+ fusion_blocks = fusion_blocks
75
+ num_images_per_prompt = num_images_per_prompt
76
+ dtype = dtype
77
+ if do_classifier_free_guidance:
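+ # uc_mask flags the entries of the doubled classifier-free-guidance batch that should ignore
+ # the reference features; in "read" mode those entries fall back to plain self-attention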
78
+ uc_mask = (
79
+ torch.Tensor(
80
+ [1] * batch_size * num_images_per_prompt * 16
81
+ + [0] * batch_size * num_images_per_prompt * 16
82
+ )
83
+ .to(device)
84
+ .bool()
85
+ )
86
+ else:
87
+ uc_mask = (
88
+ torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
89
+ .to(device)
90
+ .bool()
91
+ )
92
+
93
+ def hacked_basic_transformer_inner_forward(
94
+ self,
95
+ hidden_states: torch.FloatTensor,
96
+ attention_mask: Optional[torch.FloatTensor] = None,
97
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
98
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
99
+ timestep: Optional[torch.LongTensor] = None,
100
+ cross_attention_kwargs: Dict[str, Any] = None,
101
+ class_labels: Optional[torch.LongTensor] = None,
102
+ video_length=None,
103
+ ):
104
+ if self.use_ada_layer_norm: # False
105
+ norm_hidden_states = self.norm1(hidden_states, timestep)
106
+ elif self.use_ada_layer_norm_zero:
107
+ (
108
+ norm_hidden_states,
109
+ gate_msa,
110
+ shift_mlp,
111
+ scale_mlp,
112
+ gate_mlp,
113
+ ) = self.norm1(
114
+ hidden_states,
115
+ timestep,
116
+ class_labels,
117
+ hidden_dtype=hidden_states.dtype,
118
+ )
119
+ else:
120
+ norm_hidden_states = self.norm1(hidden_states)
121
+
122
+ # 1. Self-Attention
123
+ # self.only_cross_attention = False
124
+ cross_attention_kwargs = (
125
+ cross_attention_kwargs if cross_attention_kwargs is not None else {}
126
+ )
127
+ if self.only_cross_attention:
128
+ attn_output = self.attn1(
129
+ norm_hidden_states,
130
+ encoder_hidden_states=encoder_hidden_states
131
+ if self.only_cross_attention
132
+ else None,
133
+ attention_mask=attention_mask,
134
+ **cross_attention_kwargs,
135
+ )
136
+ else:
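+ # reference control: in "write" mode the reference UNet caches its self-attention features in
+ # self.bank; in "read" mode the denoising UNet concatenates them as extra attention context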
137
+ if MODE == "write":
138
+ self.bank.append(norm_hidden_states.clone())
139
+ attn_output = self.attn1(
140
+ norm_hidden_states,
141
+ encoder_hidden_states=encoder_hidden_states
142
+ if self.only_cross_attention
143
+ else None,
144
+ attention_mask=attention_mask,
145
+ **cross_attention_kwargs,
146
+ )
147
+ if MODE == "read":
148
+ bank_fea = [
149
+ rearrange(
150
+ d.unsqueeze(1).repeat(1, video_length, 1, 1),
151
+ "b t l c -> (b t) l c",
152
+ )
153
+ for d in self.bank
154
+ ]
155
+ modify_norm_hidden_states = torch.cat(
156
+ [norm_hidden_states] + bank_fea, dim=1
157
+ )
158
+ hidden_states_uc = (
159
+ self.attn1(
160
+ norm_hidden_states,
161
+ encoder_hidden_states=modify_norm_hidden_states,
162
+ attention_mask=attention_mask,
163
+ )
164
+ + hidden_states
165
+ )
166
+ if do_classifier_free_guidance:
167
+ hidden_states_c = hidden_states_uc.clone()
168
+ _uc_mask = uc_mask.clone()
169
+ if hidden_states.shape[0] != _uc_mask.shape[0]:
170
+ _uc_mask = (
171
+ torch.Tensor(
172
+ [1] * (hidden_states.shape[0] // 2)
173
+ + [0] * (hidden_states.shape[0] // 2)
174
+ )
175
+ .to(device)
176
+ .bool()
177
+ )
178
+ hidden_states_c[_uc_mask] = (
179
+ self.attn1(
180
+ norm_hidden_states[_uc_mask],
181
+ encoder_hidden_states=norm_hidden_states[_uc_mask],
182
+ attention_mask=attention_mask,
183
+ )
184
+ + hidden_states[_uc_mask]
185
+ )
186
+ hidden_states = hidden_states_c.clone()
187
+ else:
188
+ hidden_states = hidden_states_uc
189
+
190
+ # self.bank.clear()
191
+ if self.attn2 is not None:
192
+ # Cross-Attention
193
+ norm_hidden_states = (
194
+ self.norm2(hidden_states, timestep)
195
+ if self.use_ada_layer_norm
196
+ else self.norm2(hidden_states)
197
+ )
198
+ hidden_states = (
199
+ self.attn2(
200
+ norm_hidden_states,
201
+ encoder_hidden_states=encoder_hidden_states,
202
+ attention_mask=attention_mask,
203
+ )
204
+ + hidden_states
205
+ )
206
+
207
+ # Feed-forward
208
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
209
+
210
+ # Temporal-Attention
211
+ if self.unet_use_temporal_attention:
212
+ d = hidden_states.shape[1]
213
+ hidden_states = rearrange(
214
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
215
+ )
216
+ norm_hidden_states = (
217
+ self.norm_temp(hidden_states, timestep)
218
+ if self.use_ada_layer_norm
219
+ else self.norm_temp(hidden_states)
220
+ )
221
+ hidden_states = (
222
+ self.attn_temp(norm_hidden_states) + hidden_states
223
+ )
224
+ hidden_states = rearrange(
225
+ hidden_states, "(b d) f c -> (b f) d c", d=d
226
+ )
227
+
228
+ return hidden_states
229
+
230
+ if self.use_ada_layer_norm_zero:
231
+ attn_output = gate_msa.unsqueeze(1) * attn_output
232
+ hidden_states = attn_output + hidden_states
233
+
234
+ if self.attn2 is not None:
235
+ norm_hidden_states = (
236
+ self.norm2(hidden_states, timestep)
237
+ if self.use_ada_layer_norm
238
+ else self.norm2(hidden_states)
239
+ )
240
+
241
+ # 2. Cross-Attention
242
+ attn_output = self.attn2(
243
+ norm_hidden_states,
244
+ encoder_hidden_states=encoder_hidden_states,
245
+ attention_mask=encoder_attention_mask,
246
+ **cross_attention_kwargs,
247
+ )
248
+ hidden_states = attn_output + hidden_states
249
+
250
+ # 3. Feed-forward
251
+ norm_hidden_states = self.norm3(hidden_states)
252
+
253
+ if self.use_ada_layer_norm_zero:
254
+ norm_hidden_states = (
255
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
256
+ )
257
+
258
+ ff_output = self.ff(norm_hidden_states)
259
+
260
+ if self.use_ada_layer_norm_zero:
261
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
262
+
263
+ hidden_states = ff_output + hidden_states
264
+
265
+ return hidden_states
266
+
267
+ if self.reference_attn:
268
+ if self.fusion_blocks == "midup":
269
+ attn_modules = [
270
+ module
271
+ for module in (
272
+ torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
273
+ )
274
+ if isinstance(module, BasicTransformerBlock)
275
+ or isinstance(module, TemporalBasicTransformerBlock)
276
+ ]
277
+ elif self.fusion_blocks == "full":
278
+ attn_modules = [
279
+ module
280
+ for module in torch_dfs(self.unet)
281
+ if isinstance(module, BasicTransformerBlock)
282
+ or isinstance(module, TemporalBasicTransformerBlock)
283
+ ]
284
+ attn_modules = sorted(
285
+ attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
286
+ )
287
+
288
+ for i, module in enumerate(attn_modules):
289
+ module._original_inner_forward = module.forward
290
+ if isinstance(module, BasicTransformerBlock):
291
+ module.forward = hacked_basic_transformer_inner_forward.__get__(
292
+ module, BasicTransformerBlock
293
+ )
294
+ if isinstance(module, TemporalBasicTransformerBlock):
295
+ module.forward = hacked_basic_transformer_inner_forward.__get__(
296
+ module, TemporalBasicTransformerBlock
297
+ )
298
+
299
+ module.bank = []
300
+ module.attn_weight = float(i) / float(len(attn_modules))
301
+
302
+ def update(self, writer, dtype=torch.float16):
303
+ if self.reference_attn:
304
+ if self.fusion_blocks == "midup":
305
+ reader_attn_modules = [
306
+ module
307
+ for module in (
308
+ torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
309
+ )
310
+ if isinstance(module, TemporalBasicTransformerBlock)
311
+ ]
312
+ writer_attn_modules = [
313
+ module
314
+ for module in (
315
+ torch_dfs(writer.unet.mid_block)
316
+ + torch_dfs(writer.unet.up_blocks)
317
+ )
318
+ if isinstance(module, BasicTransformerBlock)
319
+ ]
320
+ elif self.fusion_blocks == "full":
321
+ reader_attn_modules = [
322
+ module
323
+ for module in torch_dfs(self.unet)
324
+ if isinstance(module, TemporalBasicTransformerBlock)
325
+ ]
326
+ writer_attn_modules = [
327
+ module
328
+ for module in torch_dfs(writer.unet)
329
+ if isinstance(module, BasicTransformerBlock)
330
+ ]
331
+ reader_attn_modules = sorted(
332
+ reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
333
+ )
334
+ writer_attn_modules = sorted(
335
+ writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
336
+ )
337
+ for r, w in zip(reader_attn_modules, writer_attn_modules):
338
+ r.bank = [v.clone().to(dtype) for v in w.bank]
339
+ # w.bank.clear()
340
+
341
+ def clear(self):
342
+ if self.reference_attn:
343
+ if self.fusion_blocks == "midup":
344
+ reader_attn_modules = [
345
+ module
346
+ for module in (
347
+ torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
348
+ )
349
+ if isinstance(module, BasicTransformerBlock)
350
+ or isinstance(module, TemporalBasicTransformerBlock)
351
+ ]
352
+ elif self.fusion_blocks == "full":
353
+ reader_attn_modules = [
354
+ module
355
+ for module in torch_dfs(self.unet)
356
+ if isinstance(module, BasicTransformerBlock)
357
+ or isinstance(module, TemporalBasicTransformerBlock)
358
+ ]
359
+ reader_attn_modules = sorted(
360
+ reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
361
+ )
362
+ for r in reader_attn_modules:
363
+ r.bank.clear()
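
The "write"/"read" split above is the heart of the mutual self-attention trick: the writer banks each block's normalized hidden states from the reference UNet, `update()` copies those banks into the reader's blocks, and the reader concatenates them onto its own tokens before self-attention. The snippet below is a small, self-contained illustration of just that concatenation step (the `rearrange` broadcast in the "read" branch); the tensor sizes are made up for the example and are not taken from this repository.

import torch
from einops import rearrange

b, t, l, c = 2, 8, 64, 320                     # batch, frames, spatial tokens, channels (illustrative)
banked = torch.randn(b, l, c)                  # one entry of module.bank, written by the reference UNet
norm_hidden_states = torch.randn(b * t, l, c)  # per-frame tokens inside the denoising UNet

# Repeat the reference tokens for every frame and fold frames into the batch,
# exactly as the "read" branch does before building the attention key/value sequence.
ref_per_frame = rearrange(
    banked.unsqueeze(1).repeat(1, t, 1, 1), "b t l c -> (b t) l c"
)
kv = torch.cat([norm_hidden_states, ref_per_frame], dim=1)
print(kv.shape)                                # torch.Size([16, 128, 320])
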
musepose/models/pose_guider.py ADDED
@@ -0,0 +1,57 @@
1
+ from typing import Tuple
2
+
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.nn.init as init
6
+ from diffusers.models.modeling_utils import ModelMixin
7
+
8
+ from musepose.models.motion_module import zero_module
9
+ from musepose.models.resnet import InflatedConv3d
10
+
11
+
12
+ class PoseGuider(ModelMixin):
13
+ def __init__(
14
+ self,
15
+ conditioning_embedding_channels: int,
16
+ conditioning_channels: int = 3,
17
+ block_out_channels: Tuple[int] = (16, 32, 64, 128),
18
+ ):
19
+ super().__init__()
20
+ self.conv_in = InflatedConv3d(
21
+ conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
22
+ )
23
+
24
+ self.blocks = nn.ModuleList([])
25
+
26
+ for i in range(len(block_out_channels) - 1):
27
+ channel_in = block_out_channels[i]
28
+ channel_out = block_out_channels[i + 1]
29
+ self.blocks.append(
30
+ InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)
31
+ )
32
+ self.blocks.append(
33
+ InflatedConv3d(
34
+ channel_in, channel_out, kernel_size=3, padding=1, stride=2
35
+ )
36
+ )
37
+
38
+ self.conv_out = zero_module(
39
+ InflatedConv3d(
40
+ block_out_channels[-1],
41
+ conditioning_embedding_channels,
42
+ kernel_size=3,
43
+ padding=1,
44
+ )
45
+ )
46
+
47
+ def forward(self, conditioning):
48
+ embedding = self.conv_in(conditioning)
49
+ embedding = F.silu(embedding)
50
+
51
+ for block in self.blocks:
52
+ embedding = block(embedding)
53
+ embedding = F.silu(embedding)
54
+
55
+ embedding = self.conv_out(embedding)
56
+
57
+ return embedding
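
`PoseGuider` stacks three stride-2 inflated convolutions, so the pose map is downsampled by 8 in both spatial dimensions and projected to `conditioning_embedding_channels`; because `conv_out` is zero-initialized, the module contributes nothing at the start of training. A quick shape check, assuming the repository root is on `PYTHONPATH` so the `musepose` package is importable:

import torch
from musepose.models.pose_guider import PoseGuider

pose_guider = PoseGuider(conditioning_embedding_channels=320)
pose = torch.randn(1, 3, 4, 64, 48)  # (batch, channels, frames, height, width)
with torch.no_grad():
    emb = pose_guider(pose)
print(emb.shape)                     # torch.Size([1, 320, 4, 8, 6])
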
musepose/models/resnet.py ADDED
@@ -0,0 +1,252 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+
8
+
9
+ class InflatedConv3d(nn.Conv2d):
10
+ def forward(self, x):
11
+ video_length = x.shape[2]
12
+
13
+ x = rearrange(x, "b c f h w -> (b f) c h w")
14
+ x = super().forward(x)
15
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
16
+
17
+ return x
18
+
19
+
20
+ class InflatedGroupNorm(nn.GroupNorm):
21
+ def forward(self, x):
22
+ video_length = x.shape[2]
23
+
24
+ x = rearrange(x, "b c f h w -> (b f) c h w")
25
+ x = super().forward(x)
26
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
27
+
28
+ return x
29
+
30
+
31
+ class Upsample3D(nn.Module):
32
+ def __init__(
33
+ self,
34
+ channels,
35
+ use_conv=False,
36
+ use_conv_transpose=False,
37
+ out_channels=None,
38
+ name="conv",
39
+ ):
40
+ super().__init__()
41
+ self.channels = channels
42
+ self.out_channels = out_channels or channels
43
+ self.use_conv = use_conv
44
+ self.use_conv_transpose = use_conv_transpose
45
+ self.name = name
46
+
47
+ conv = None
48
+ if use_conv_transpose:
49
+ raise NotImplementedError
50
+ elif use_conv:
51
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
52
+
53
+ def forward(self, hidden_states, output_size=None):
54
+ assert hidden_states.shape[1] == self.channels
55
+
56
+ if self.use_conv_transpose:
57
+ raise NotImplementedError
58
+
59
+ # Cast to float32 because the 'upsample_nearest2d_out_frame' op does not support bfloat16
60
+ dtype = hidden_states.dtype
61
+ if dtype == torch.bfloat16:
62
+ hidden_states = hidden_states.to(torch.float32)
63
+
64
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
65
+ if hidden_states.shape[0] >= 64:
66
+ hidden_states = hidden_states.contiguous()
67
+
68
+ # if `output_size` is passed we force the interpolation output
69
+ # size and do not make use of `scale_factor=2`
70
+ if output_size is None:
71
+ hidden_states = F.interpolate(
72
+ hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest"
73
+ )
74
+ else:
75
+ hidden_states = F.interpolate(
76
+ hidden_states, size=output_size, mode="nearest"
77
+ )
78
+
79
+ # If the input is bfloat16, we cast back to bfloat16
80
+ if dtype == torch.bfloat16:
81
+ hidden_states = hidden_states.to(dtype)
82
+
83
+ # if self.use_conv:
84
+ # if self.name == "conv":
85
+ # hidden_states = self.conv(hidden_states)
86
+ # else:
87
+ # hidden_states = self.Conv2d_0(hidden_states)
88
+ hidden_states = self.conv(hidden_states)
89
+
90
+ return hidden_states
91
+
92
+
93
+ class Downsample3D(nn.Module):
94
+ def __init__(
95
+ self, channels, use_conv=False, out_channels=None, padding=1, name="conv"
96
+ ):
97
+ super().__init__()
98
+ self.channels = channels
99
+ self.out_channels = out_channels or channels
100
+ self.use_conv = use_conv
101
+ self.padding = padding
102
+ stride = 2
103
+ self.name = name
104
+
105
+ if use_conv:
106
+ self.conv = InflatedConv3d(
107
+ self.channels, self.out_channels, 3, stride=stride, padding=padding
108
+ )
109
+ else:
110
+ raise NotImplementedError
111
+
112
+ def forward(self, hidden_states):
113
+ assert hidden_states.shape[1] == self.channels
114
+ if self.use_conv and self.padding == 0:
115
+ raise NotImplementedError
116
+
117
+ assert hidden_states.shape[1] == self.channels
118
+ hidden_states = self.conv(hidden_states)
119
+
120
+ return hidden_states
121
+
122
+
123
+ class ResnetBlock3D(nn.Module):
124
+ def __init__(
125
+ self,
126
+ *,
127
+ in_channels,
128
+ out_channels=None,
129
+ conv_shortcut=False,
130
+ dropout=0.0,
131
+ temb_channels=512,
132
+ groups=32,
133
+ groups_out=None,
134
+ pre_norm=True,
135
+ eps=1e-6,
136
+ non_linearity="swish",
137
+ time_embedding_norm="default",
138
+ output_scale_factor=1.0,
139
+ use_in_shortcut=None,
140
+ use_inflated_groupnorm=None,
141
+ ):
142
+ super().__init__()
143
+ self.pre_norm = pre_norm
144
+ self.pre_norm = True
145
+ self.in_channels = in_channels
146
+ out_channels = in_channels if out_channels is None else out_channels
147
+ self.out_channels = out_channels
148
+ self.use_conv_shortcut = conv_shortcut
149
+ self.time_embedding_norm = time_embedding_norm
150
+ self.output_scale_factor = output_scale_factor
151
+
152
+ if groups_out is None:
153
+ groups_out = groups
154
+
155
+ assert use_inflated_groupnorm is not None
156
+ if use_inflated_groupnorm:
157
+ self.norm1 = InflatedGroupNorm(
158
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
159
+ )
160
+ else:
161
+ self.norm1 = torch.nn.GroupNorm(
162
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
163
+ )
164
+
165
+ self.conv1 = InflatedConv3d(
166
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
167
+ )
168
+
169
+ if temb_channels is not None:
170
+ if self.time_embedding_norm == "default":
171
+ time_emb_proj_out_channels = out_channels
172
+ elif self.time_embedding_norm == "scale_shift":
173
+ time_emb_proj_out_channels = out_channels * 2
174
+ else:
175
+ raise ValueError(
176
+ f"unknown time_embedding_norm : {self.time_embedding_norm} "
177
+ )
178
+
179
+ self.time_emb_proj = torch.nn.Linear(
180
+ temb_channels, time_emb_proj_out_channels
181
+ )
182
+ else:
183
+ self.time_emb_proj = None
184
+
185
+ if use_inflated_groupnorm:
186
+ self.norm2 = InflatedGroupNorm(
187
+ num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
188
+ )
189
+ else:
190
+ self.norm2 = torch.nn.GroupNorm(
191
+ num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
192
+ )
193
+ self.dropout = torch.nn.Dropout(dropout)
194
+ self.conv2 = InflatedConv3d(
195
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
196
+ )
197
+
198
+ if non_linearity == "swish":
199
+ self.nonlinearity = lambda x: F.silu(x)
200
+ elif non_linearity == "mish":
201
+ self.nonlinearity = Mish()
202
+ elif non_linearity == "silu":
203
+ self.nonlinearity = nn.SiLU()
204
+
205
+ self.use_in_shortcut = (
206
+ self.in_channels != self.out_channels
207
+ if use_in_shortcut is None
208
+ else use_in_shortcut
209
+ )
210
+
211
+ self.conv_shortcut = None
212
+ if self.use_in_shortcut:
213
+ self.conv_shortcut = InflatedConv3d(
214
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
215
+ )
216
+
217
+ def forward(self, input_tensor, temb):
218
+ hidden_states = input_tensor
219
+
220
+ hidden_states = self.norm1(hidden_states)
221
+ hidden_states = self.nonlinearity(hidden_states)
222
+
223
+ hidden_states = self.conv1(hidden_states)
224
+
225
+ if temb is not None:
226
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
227
+
228
+ if temb is not None and self.time_embedding_norm == "default":
229
+ hidden_states = hidden_states + temb
230
+
231
+ hidden_states = self.norm2(hidden_states)
232
+
233
+ if temb is not None and self.time_embedding_norm == "scale_shift":
234
+ scale, shift = torch.chunk(temb, 2, dim=1)
235
+ hidden_states = hidden_states * (1 + scale) + shift
236
+
237
+ hidden_states = self.nonlinearity(hidden_states)
238
+
239
+ hidden_states = self.dropout(hidden_states)
240
+ hidden_states = self.conv2(hidden_states)
241
+
242
+ if self.conv_shortcut is not None:
243
+ input_tensor = self.conv_shortcut(input_tensor)
244
+
245
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
246
+
247
+ return output_tensor
248
+
249
+
250
+ class Mish(torch.nn.Module):
251
+ def forward(self, hidden_states):
252
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
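
`InflatedConv3d` and `InflatedGroupNorm` simply fold the frame axis into the batch, apply the ordinary 2D operation, and unfold again, which is what lets pretrained 2D Stable Diffusion weights be loaded into these 3D blocks unchanged. A minimal check of that equivalence (again assuming the `musepose` package is importable):

import torch
import torch.nn as nn
from musepose.models.resnet import InflatedConv3d

conv = InflatedConv3d(4, 320, kernel_size=3, padding=1)
latents = torch.randn(2, 4, 8, 32, 32)  # (batch, channels, frames, height, width)
with torch.no_grad():
    out = conv(latents)  # folds frames into the batch, runs Conv2d, unfolds
    # Same result as applying the underlying Conv2d to each frame separately.
    per_frame = torch.stack(
        [nn.Conv2d.forward(conv, latents[:, :, i]) for i in range(latents.shape[2])],
        dim=2,
    )
print(out.shape)                                  # torch.Size([2, 320, 8, 32, 32])
print(torch.allclose(out, per_frame, atol=1e-6))  # True
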
musepose/models/transformer_2d.py ADDED
@@ -0,0 +1,395 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformer_2d.py
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, Optional
4
+
5
+ import torch
6
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
7
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
8
+ from diffusers.models.modeling_utils import ModelMixin
9
+ from diffusers.models.normalization import AdaLayerNormSingle
10
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
11
+ from torch import nn
12
+
13
+ from .attention import BasicTransformerBlock
14
+
15
+
16
+ @dataclass
17
+ class Transformer2DModelOutput(BaseOutput):
18
+ """
19
+ The output of [`Transformer2DModel`].
20
+
21
+ Args:
22
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
23
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
24
+ distributions for the unnoised latent pixels.
25
+ """
26
+
27
+ sample: torch.FloatTensor
28
+ ref_feature: torch.FloatTensor
29
+
30
+
31
+ class Transformer2DModel(ModelMixin, ConfigMixin):
32
+ """
33
+ A 2D Transformer model for image-like data.
34
+
35
+ Parameters:
36
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
37
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
38
+ in_channels (`int`, *optional*):
39
+ The number of channels in the input and output (specify if the input is **continuous**).
40
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
41
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
42
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
43
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
44
+ This is fixed during training since it is used to learn a number of position embeddings.
45
+ num_vector_embeds (`int`, *optional*):
46
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
47
+ Includes the class for the masked latent pixel.
48
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
49
+ num_embeds_ada_norm ( `int`, *optional*):
50
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
51
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
52
+ added to the hidden states.
53
+
54
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
55
+ attention_bias (`bool`, *optional*):
56
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
57
+ """
58
+
59
+ _supports_gradient_checkpointing = True
60
+
61
+ @register_to_config
62
+ def __init__(
63
+ self,
64
+ num_attention_heads: int = 16,
65
+ attention_head_dim: int = 88,
66
+ in_channels: Optional[int] = None,
67
+ out_channels: Optional[int] = None,
68
+ num_layers: int = 1,
69
+ dropout: float = 0.0,
70
+ norm_num_groups: int = 32,
71
+ cross_attention_dim: Optional[int] = None,
72
+ attention_bias: bool = False,
73
+ sample_size: Optional[int] = None,
74
+ num_vector_embeds: Optional[int] = None,
75
+ patch_size: Optional[int] = None,
76
+ activation_fn: str = "geglu",
77
+ num_embeds_ada_norm: Optional[int] = None,
78
+ use_linear_projection: bool = False,
79
+ only_cross_attention: bool = False,
80
+ double_self_attention: bool = False,
81
+ upcast_attention: bool = False,
82
+ norm_type: str = "layer_norm",
83
+ norm_elementwise_affine: bool = True,
84
+ norm_eps: float = 1e-5,
85
+ attention_type: str = "default",
86
+ caption_channels: int = None,
87
+ ):
88
+ super().__init__()
89
+ self.use_linear_projection = use_linear_projection
90
+ self.num_attention_heads = num_attention_heads
91
+ self.attention_head_dim = attention_head_dim
92
+ inner_dim = num_attention_heads * attention_head_dim
93
+
94
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
95
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
96
+
97
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
98
+ # Define whether input is continuous or discrete depending on configuration
99
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
100
+ self.is_input_vectorized = num_vector_embeds is not None
101
+ self.is_input_patches = in_channels is not None and patch_size is not None
102
+
103
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
104
+ deprecation_message = (
105
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
106
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
107
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
108
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
109
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
110
+ )
111
+ deprecate(
112
+ "norm_type!=num_embeds_ada_norm",
113
+ "1.0.0",
114
+ deprecation_message,
115
+ standard_warn=False,
116
+ )
117
+ norm_type = "ada_norm"
118
+
119
+ if self.is_input_continuous and self.is_input_vectorized:
120
+ raise ValueError(
121
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
122
+ " sure that either `in_channels` or `num_vector_embeds` is None."
123
+ )
124
+ elif self.is_input_vectorized and self.is_input_patches:
125
+ raise ValueError(
126
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
127
+ " sure that either `num_vector_embeds` or `num_patches` is None."
128
+ )
129
+ elif (
130
+ not self.is_input_continuous
131
+ and not self.is_input_vectorized
132
+ and not self.is_input_patches
133
+ ):
134
+ raise ValueError(
135
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
136
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
137
+ )
138
+
139
+ # 2. Define input layers
140
+ self.in_channels = in_channels
141
+
142
+ self.norm = torch.nn.GroupNorm(
143
+ num_groups=norm_num_groups,
144
+ num_channels=in_channels,
145
+ eps=1e-6,
146
+ affine=True,
147
+ )
148
+ if use_linear_projection:
149
+ self.proj_in = linear_cls(in_channels, inner_dim)
150
+ else:
151
+ self.proj_in = conv_cls(
152
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
153
+ )
154
+
155
+ # 3. Define transformers blocks
156
+ self.transformer_blocks = nn.ModuleList(
157
+ [
158
+ BasicTransformerBlock(
159
+ inner_dim,
160
+ num_attention_heads,
161
+ attention_head_dim,
162
+ dropout=dropout,
163
+ cross_attention_dim=cross_attention_dim,
164
+ activation_fn=activation_fn,
165
+ num_embeds_ada_norm=num_embeds_ada_norm,
166
+ attention_bias=attention_bias,
167
+ only_cross_attention=only_cross_attention,
168
+ double_self_attention=double_self_attention,
169
+ upcast_attention=upcast_attention,
170
+ norm_type=norm_type,
171
+ norm_elementwise_affine=norm_elementwise_affine,
172
+ norm_eps=norm_eps,
173
+ attention_type=attention_type,
174
+ )
175
+ for d in range(num_layers)
176
+ ]
177
+ )
178
+
179
+ # 4. Define output layers
180
+ self.out_channels = in_channels if out_channels is None else out_channels
181
+ # TODO: should use out_channels for continuous projections
182
+ if use_linear_projection:
183
+ self.proj_out = linear_cls(inner_dim, in_channels)
184
+ else:
185
+ self.proj_out = conv_cls(
186
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0
187
+ )
188
+
189
+ # 5. PixArt-Alpha blocks.
190
+ self.adaln_single = None
191
+ self.use_additional_conditions = False
192
+ if norm_type == "ada_norm_single":
193
+ self.use_additional_conditions = self.config.sample_size == 128
194
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
195
+ # additional conditions until we find better name
196
+ self.adaln_single = AdaLayerNormSingle(
197
+ inner_dim, use_additional_conditions=self.use_additional_conditions
198
+ )
199
+
200
+ self.caption_projection = None
201
+ if caption_channels is not None:
202
+ self.caption_projection = CaptionProjection(
203
+ in_features=caption_channels, hidden_size=inner_dim
204
+ )
205
+
206
+ self.gradient_checkpointing = False
207
+
208
+ def _set_gradient_checkpointing(self, module, value=False):
209
+ if hasattr(module, "gradient_checkpointing"):
210
+ module.gradient_checkpointing = value
211
+
212
+ def forward(
213
+ self,
214
+ hidden_states: torch.Tensor,
215
+ encoder_hidden_states: Optional[torch.Tensor] = None,
216
+ timestep: Optional[torch.LongTensor] = None,
217
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
218
+ class_labels: Optional[torch.LongTensor] = None,
219
+ cross_attention_kwargs: Dict[str, Any] = None,
220
+ attention_mask: Optional[torch.Tensor] = None,
221
+ encoder_attention_mask: Optional[torch.Tensor] = None,
222
+ return_dict: bool = True,
223
+ ):
224
+ """
225
+ The [`Transformer2DModel`] forward method.
226
+
227
+ Args:
228
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
229
+ Input `hidden_states`.
230
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
231
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
232
+ self-attention.
233
+ timestep ( `torch.LongTensor`, *optional*):
234
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
235
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
236
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
237
+ `AdaLayerZeroNorm`.
238
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
239
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
240
+ `self.processor` in
241
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
242
+ attention_mask ( `torch.Tensor`, *optional*):
243
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
244
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
245
+ negative values to the attention scores corresponding to "discard" tokens.
246
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
247
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
248
+
249
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
250
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
251
+
252
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
253
+ above. This bias will be added to the cross-attention scores.
254
+ return_dict (`bool`, *optional*, defaults to `True`):
255
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
256
+ tuple.
257
+
258
+ Returns:
259
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
260
+ `tuple` where the first element is the sample tensor.
261
+ """
262
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
263
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
264
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
265
+ # expects mask of shape:
266
+ # [batch, key_tokens]
267
+ # adds singleton query_tokens dimension:
268
+ # [batch, 1, key_tokens]
269
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
270
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
271
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
272
+ if attention_mask is not None and attention_mask.ndim == 2:
273
+ # assume that mask is expressed as:
274
+ # (1 = keep, 0 = discard)
275
+ # convert mask into a bias that can be added to attention scores:
276
+ # (keep = +0, discard = -10000.0)
277
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
278
+ attention_mask = attention_mask.unsqueeze(1)
279
+
280
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
281
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
282
+ encoder_attention_mask = (
283
+ 1 - encoder_attention_mask.to(hidden_states.dtype)
284
+ ) * -10000.0
285
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
286
+
287
+ # Retrieve lora scale.
288
+ lora_scale = (
289
+ cross_attention_kwargs.get("scale", 1.0)
290
+ if cross_attention_kwargs is not None
291
+ else 1.0
292
+ )
293
+
294
+ # 1. Input
295
+ batch, _, height, width = hidden_states.shape
296
+ residual = hidden_states
297
+
298
+ hidden_states = self.norm(hidden_states)
299
+ if not self.use_linear_projection:
300
+ hidden_states = (
301
+ self.proj_in(hidden_states, scale=lora_scale)
302
+ if not USE_PEFT_BACKEND
303
+ else self.proj_in(hidden_states)
304
+ )
305
+ inner_dim = hidden_states.shape[1]
306
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
307
+ batch, height * width, inner_dim
308
+ )
309
+ else:
310
+ inner_dim = hidden_states.shape[1]
311
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
312
+ batch, height * width, inner_dim
313
+ )
314
+ hidden_states = (
315
+ self.proj_in(hidden_states, scale=lora_scale)
316
+ if not USE_PEFT_BACKEND
317
+ else self.proj_in(hidden_states)
318
+ )
319
+
320
+ # 2. Blocks
321
+ if self.caption_projection is not None:
322
+ batch_size = hidden_states.shape[0]
323
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
324
+ encoder_hidden_states = encoder_hidden_states.view(
325
+ batch_size, -1, hidden_states.shape[-1]
326
+ )
327
+
328
+ ref_feature = hidden_states.reshape(batch, height, width, inner_dim)
329
+ for block in self.transformer_blocks:
330
+ if self.training and self.gradient_checkpointing:
331
+
332
+ def create_custom_forward(module, return_dict=None):
333
+ def custom_forward(*inputs):
334
+ if return_dict is not None:
335
+ return module(*inputs, return_dict=return_dict)
336
+ else:
337
+ return module(*inputs)
338
+
339
+ return custom_forward
340
+
341
+ ckpt_kwargs: Dict[str, Any] = (
342
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
343
+ )
344
+ hidden_states = torch.utils.checkpoint.checkpoint(
345
+ create_custom_forward(block),
346
+ hidden_states,
347
+ attention_mask,
348
+ encoder_hidden_states,
349
+ encoder_attention_mask,
350
+ timestep,
351
+ cross_attention_kwargs,
352
+ class_labels,
353
+ **ckpt_kwargs,
354
+ )
355
+ else:
356
+ hidden_states = block(
357
+ hidden_states,
358
+ attention_mask=attention_mask,
359
+ encoder_hidden_states=encoder_hidden_states,
360
+ encoder_attention_mask=encoder_attention_mask,
361
+ timestep=timestep,
362
+ cross_attention_kwargs=cross_attention_kwargs,
363
+ class_labels=class_labels,
364
+ )
365
+
366
+ # 3. Output
367
+ if self.is_input_continuous:
368
+ if not self.use_linear_projection:
369
+ hidden_states = (
370
+ hidden_states.reshape(batch, height, width, inner_dim)
371
+ .permute(0, 3, 1, 2)
372
+ .contiguous()
373
+ )
374
+ hidden_states = (
375
+ self.proj_out(hidden_states, scale=lora_scale)
376
+ if not USE_PEFT_BACKEND
377
+ else self.proj_out(hidden_states)
378
+ )
379
+ else:
380
+ hidden_states = (
381
+ self.proj_out(hidden_states, scale=lora_scale)
382
+ if not USE_PEFT_BACKEND
383
+ else self.proj_out(hidden_states)
384
+ )
385
+ hidden_states = (
386
+ hidden_states.reshape(batch, height, width, inner_dim)
387
+ .permute(0, 3, 1, 2)
388
+ .contiguous()
389
+ )
390
+
391
+ output = hidden_states + residual
392
+ if not return_dict:
393
+ return (output, ref_feature)
394
+
395
+ return Transformer2DModelOutput(sample=output, ref_feature=ref_feature)
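
Relative to the stock diffusers `Transformer2DModel`, the key change in this copy is the extra `ref_feature` output: the hidden states captured right after `proj_in`, reshaped to `(batch, height, width, inner_dim)`, which is what the reference-attention machinery reads. A usage sketch with illustrative sizes, assuming a diffusers version compatible with the imports above:

import torch
from musepose.models.transformer_2d import Transformer2DModel

model = Transformer2DModel(
    num_attention_heads=8,
    attention_head_dim=40,  # inner_dim = 8 * 40 = 320
    in_channels=320,
    num_layers=1,
    cross_attention_dim=768,
)
latents = torch.randn(1, 320, 16, 16)
text_embeds = torch.randn(1, 77, 768)
with torch.no_grad():
    out = model(latents, encoder_hidden_states=text_embeds)
print(out.sample.shape)       # torch.Size([1, 320, 16, 16])
print(out.ref_feature.shape)  # torch.Size([1, 16, 16, 320])
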
musepose/models/transformer_3d.py ADDED
@@ -0,0 +1,169 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
6
+ from diffusers.models import ModelMixin
7
+ from diffusers.utils import BaseOutput
8
+ from diffusers.utils.import_utils import is_xformers_available
9
+ from einops import rearrange, repeat
10
+ from torch import nn
11
+
12
+ from .attention import TemporalBasicTransformerBlock
13
+
14
+
15
+ @dataclass
16
+ class Transformer3DModelOutput(BaseOutput):
17
+ sample: torch.FloatTensor
18
+
19
+
20
+ if is_xformers_available():
21
+ import xformers
22
+ import xformers.ops
23
+ else:
24
+ xformers = None
25
+
26
+
27
+ class Transformer3DModel(ModelMixin, ConfigMixin):
28
+ _supports_gradient_checkpointing = True
29
+
30
+ @register_to_config
31
+ def __init__(
32
+ self,
33
+ num_attention_heads: int = 16,
34
+ attention_head_dim: int = 88,
35
+ in_channels: Optional[int] = None,
36
+ num_layers: int = 1,
37
+ dropout: float = 0.0,
38
+ norm_num_groups: int = 32,
39
+ cross_attention_dim: Optional[int] = None,
40
+ attention_bias: bool = False,
41
+ activation_fn: str = "geglu",
42
+ num_embeds_ada_norm: Optional[int] = None,
43
+ use_linear_projection: bool = False,
44
+ only_cross_attention: bool = False,
45
+ upcast_attention: bool = False,
46
+ unet_use_cross_frame_attention=None,
47
+ unet_use_temporal_attention=None,
48
+ ):
49
+ super().__init__()
50
+ self.use_linear_projection = use_linear_projection
51
+ self.num_attention_heads = num_attention_heads
52
+ self.attention_head_dim = attention_head_dim
53
+ inner_dim = num_attention_heads * attention_head_dim
54
+
55
+ # Define input layers
56
+ self.in_channels = in_channels
57
+
58
+ self.norm = torch.nn.GroupNorm(
59
+ num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
60
+ )
61
+ if use_linear_projection:
62
+ self.proj_in = nn.Linear(in_channels, inner_dim)
63
+ else:
64
+ self.proj_in = nn.Conv2d(
65
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
66
+ )
67
+
68
+ # Define transformers blocks
69
+ self.transformer_blocks = nn.ModuleList(
70
+ [
71
+ TemporalBasicTransformerBlock(
72
+ inner_dim,
73
+ num_attention_heads,
74
+ attention_head_dim,
75
+ dropout=dropout,
76
+ cross_attention_dim=cross_attention_dim,
77
+ activation_fn=activation_fn,
78
+ num_embeds_ada_norm=num_embeds_ada_norm,
79
+ attention_bias=attention_bias,
80
+ only_cross_attention=only_cross_attention,
81
+ upcast_attention=upcast_attention,
82
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
83
+ unet_use_temporal_attention=unet_use_temporal_attention,
84
+ )
85
+ for d in range(num_layers)
86
+ ]
87
+ )
88
+
89
+ # 4. Define output layers
90
+ if use_linear_projection:
91
+ self.proj_out = nn.Linear(in_channels, inner_dim)
92
+ else:
93
+ self.proj_out = nn.Conv2d(
94
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0
95
+ )
96
+
97
+ self.gradient_checkpointing = False
98
+
99
+ def _set_gradient_checkpointing(self, module, value=False):
100
+ if hasattr(module, "gradient_checkpointing"):
101
+ module.gradient_checkpointing = value
102
+
103
+ def forward(
104
+ self,
105
+ hidden_states,
106
+ encoder_hidden_states=None,
107
+ timestep=None,
108
+ return_dict: bool = True,
109
+ ):
110
+ # Input
111
+ assert (
112
+ hidden_states.dim() == 5
113
+ ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
114
+ video_length = hidden_states.shape[2]
115
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
116
+ if encoder_hidden_states.shape[0] != hidden_states.shape[0]:
117
+ encoder_hidden_states = repeat(
118
+ encoder_hidden_states, "b n c -> (b f) n c", f=video_length
119
+ )
120
+
121
+ batch, channel, height, weight = hidden_states.shape
122
+ residual = hidden_states
123
+
124
+ hidden_states = self.norm(hidden_states)
125
+ if not self.use_linear_projection:
126
+ hidden_states = self.proj_in(hidden_states)
127
+ inner_dim = hidden_states.shape[1]
128
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
129
+ batch, height * weight, inner_dim
130
+ )
131
+ else:
132
+ inner_dim = hidden_states.shape[1]
133
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
134
+ batch, height * weight, inner_dim
135
+ )
136
+ hidden_states = self.proj_in(hidden_states)
137
+
138
+ # Blocks
139
+ for i, block in enumerate(self.transformer_blocks):
140
+ hidden_states = block(
141
+ hidden_states,
142
+ encoder_hidden_states=encoder_hidden_states,
143
+ timestep=timestep,
144
+ video_length=video_length,
145
+ )
146
+
147
+ # Output
148
+ if not self.use_linear_projection:
149
+ hidden_states = (
150
+ hidden_states.reshape(batch, height, weight, inner_dim)
151
+ .permute(0, 3, 1, 2)
152
+ .contiguous()
153
+ )
154
+ hidden_states = self.proj_out(hidden_states)
155
+ else:
156
+ hidden_states = self.proj_out(hidden_states)
157
+ hidden_states = (
158
+ hidden_states.reshape(batch, height, weight, inner_dim)
159
+ .permute(0, 3, 1, 2)
160
+ .contiguous()
161
+ )
162
+
163
+ output = hidden_states + residual
164
+
165
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
166
+ if not return_dict:
167
+ return (output,)
168
+
169
+ return Transformer3DModelOutput(sample=output)
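
`Transformer3DModel` takes 5-D latents, folds frames into the batch for the spatial blocks, and repeats a single conditioning embedding for every frame when the batch sizes differ. A usage sketch with illustrative sizes (the temporal and cross-frame options are switched off here to keep the example minimal):

import torch
from musepose.models.transformer_3d import Transformer3DModel

model = Transformer3DModel(
    num_attention_heads=8,
    attention_head_dim=40,  # inner_dim = 320
    in_channels=320,
    num_layers=1,
    cross_attention_dim=768,
    unet_use_cross_frame_attention=False,
    unet_use_temporal_attention=False,
)
latents = torch.randn(1, 320, 8, 16, 16)  # (batch, channels, frames, height, width)
context = torch.randn(1, 77, 768)         # repeated per frame inside forward()
with torch.no_grad():
    out = model(latents, encoder_hidden_states=context)
print(out.sample.shape)                   # torch.Size([1, 320, 8, 16, 16])
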
musepose/models/unet_2d_blocks.py ADDED
@@ -0,0 +1,1074 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+ from typing import Any, Dict, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from diffusers.models.activations import get_activation
8
+ from diffusers.models.attention_processor import Attention
9
+ from diffusers.models.dual_transformer_2d import DualTransformer2DModel
10
+ from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
11
+ from diffusers.utils import is_torch_version, logging
12
+ from diffusers.utils.torch_utils import apply_freeu
13
+ from torch import nn
14
+
15
+ from .transformer_2d import Transformer2DModel
16
+
17
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
18
+
19
+
20
+ def get_down_block(
21
+ down_block_type: str,
22
+ num_layers: int,
23
+ in_channels: int,
24
+ out_channels: int,
25
+ temb_channels: int,
26
+ add_downsample: bool,
27
+ resnet_eps: float,
28
+ resnet_act_fn: str,
29
+ transformer_layers_per_block: int = 1,
30
+ num_attention_heads: Optional[int] = None,
31
+ resnet_groups: Optional[int] = None,
32
+ cross_attention_dim: Optional[int] = None,
33
+ downsample_padding: Optional[int] = None,
34
+ dual_cross_attention: bool = False,
35
+ use_linear_projection: bool = False,
36
+ only_cross_attention: bool = False,
37
+ upcast_attention: bool = False,
38
+ resnet_time_scale_shift: str = "default",
39
+ attention_type: str = "default",
40
+ resnet_skip_time_act: bool = False,
41
+ resnet_out_scale_factor: float = 1.0,
42
+ cross_attention_norm: Optional[str] = None,
43
+ attention_head_dim: Optional[int] = None,
44
+ downsample_type: Optional[str] = None,
45
+ dropout: float = 0.0,
46
+ ):
47
+ # If attn head dim is not defined, we default it to the number of heads
48
+ if attention_head_dim is None:
49
+ logger.warn(
50
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
51
+ )
52
+ attention_head_dim = num_attention_heads
53
+
54
+ down_block_type = (
55
+ down_block_type[7:]
56
+ if down_block_type.startswith("UNetRes")
57
+ else down_block_type
58
+ )
59
+ if down_block_type == "DownBlock2D":
60
+ return DownBlock2D(
61
+ num_layers=num_layers,
62
+ in_channels=in_channels,
63
+ out_channels=out_channels,
64
+ temb_channels=temb_channels,
65
+ dropout=dropout,
66
+ add_downsample=add_downsample,
67
+ resnet_eps=resnet_eps,
68
+ resnet_act_fn=resnet_act_fn,
69
+ resnet_groups=resnet_groups,
70
+ downsample_padding=downsample_padding,
71
+ resnet_time_scale_shift=resnet_time_scale_shift,
72
+ )
73
+ elif down_block_type == "CrossAttnDownBlock2D":
74
+ if cross_attention_dim is None:
75
+ raise ValueError(
76
+ "cross_attention_dim must be specified for CrossAttnDownBlock2D"
77
+ )
78
+ return CrossAttnDownBlock2D(
79
+ num_layers=num_layers,
80
+ transformer_layers_per_block=transformer_layers_per_block,
81
+ in_channels=in_channels,
82
+ out_channels=out_channels,
83
+ temb_channels=temb_channels,
84
+ dropout=dropout,
85
+ add_downsample=add_downsample,
86
+ resnet_eps=resnet_eps,
87
+ resnet_act_fn=resnet_act_fn,
88
+ resnet_groups=resnet_groups,
89
+ downsample_padding=downsample_padding,
90
+ cross_attention_dim=cross_attention_dim,
91
+ num_attention_heads=num_attention_heads,
92
+ dual_cross_attention=dual_cross_attention,
93
+ use_linear_projection=use_linear_projection,
94
+ only_cross_attention=only_cross_attention,
95
+ upcast_attention=upcast_attention,
96
+ resnet_time_scale_shift=resnet_time_scale_shift,
97
+ attention_type=attention_type,
98
+ )
99
+ raise ValueError(f"{down_block_type} does not exist.")
100
+
101
+
102
+ def get_up_block(
103
+ up_block_type: str,
104
+ num_layers: int,
105
+ in_channels: int,
106
+ out_channels: int,
107
+ prev_output_channel: int,
108
+ temb_channels: int,
109
+ add_upsample: bool,
110
+ resnet_eps: float,
111
+ resnet_act_fn: str,
112
+ resolution_idx: Optional[int] = None,
113
+ transformer_layers_per_block: int = 1,
114
+ num_attention_heads: Optional[int] = None,
115
+ resnet_groups: Optional[int] = None,
116
+ cross_attention_dim: Optional[int] = None,
117
+ dual_cross_attention: bool = False,
118
+ use_linear_projection: bool = False,
119
+ only_cross_attention: bool = False,
120
+ upcast_attention: bool = False,
121
+ resnet_time_scale_shift: str = "default",
122
+ attention_type: str = "default",
123
+ resnet_skip_time_act: bool = False,
124
+ resnet_out_scale_factor: float = 1.0,
125
+ cross_attention_norm: Optional[str] = None,
126
+ attention_head_dim: Optional[int] = None,
127
+ upsample_type: Optional[str] = None,
128
+ dropout: float = 0.0,
129
+ ) -> nn.Module:
130
+ # If attn head dim is not defined, we default it to the number of heads
131
+ if attention_head_dim is None:
132
+ logger.warn(
133
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
134
+ )
135
+ attention_head_dim = num_attention_heads
136
+
137
+ up_block_type = (
138
+ up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
139
+ )
140
+ if up_block_type == "UpBlock2D":
141
+ return UpBlock2D(
142
+ num_layers=num_layers,
143
+ in_channels=in_channels,
144
+ out_channels=out_channels,
145
+ prev_output_channel=prev_output_channel,
146
+ temb_channels=temb_channels,
147
+ resolution_idx=resolution_idx,
148
+ dropout=dropout,
149
+ add_upsample=add_upsample,
150
+ resnet_eps=resnet_eps,
151
+ resnet_act_fn=resnet_act_fn,
152
+ resnet_groups=resnet_groups,
153
+ resnet_time_scale_shift=resnet_time_scale_shift,
154
+ )
155
+ elif up_block_type == "CrossAttnUpBlock2D":
156
+ if cross_attention_dim is None:
157
+ raise ValueError(
158
+ "cross_attention_dim must be specified for CrossAttnUpBlock2D"
159
+ )
160
+ return CrossAttnUpBlock2D(
161
+ num_layers=num_layers,
162
+ transformer_layers_per_block=transformer_layers_per_block,
163
+ in_channels=in_channels,
164
+ out_channels=out_channels,
165
+ prev_output_channel=prev_output_channel,
166
+ temb_channels=temb_channels,
167
+ resolution_idx=resolution_idx,
168
+ dropout=dropout,
169
+ add_upsample=add_upsample,
170
+ resnet_eps=resnet_eps,
171
+ resnet_act_fn=resnet_act_fn,
172
+ resnet_groups=resnet_groups,
173
+ cross_attention_dim=cross_attention_dim,
174
+ num_attention_heads=num_attention_heads,
175
+ dual_cross_attention=dual_cross_attention,
176
+ use_linear_projection=use_linear_projection,
177
+ only_cross_attention=only_cross_attention,
178
+ upcast_attention=upcast_attention,
179
+ resnet_time_scale_shift=resnet_time_scale_shift,
180
+ attention_type=attention_type,
181
+ )
182
+
183
+ raise ValueError(f"{up_block_type} does not exist.")
184
+
185
+
186
+ class AutoencoderTinyBlock(nn.Module):
187
+ """
188
+ Tiny Autoencoder block used in [`AutoencoderTiny`]. It is a mini residual module consisting of plain conv + ReLU
189
+ blocks.
190
+
191
+ Args:
192
+ in_channels (`int`): The number of input channels.
193
+ out_channels (`int`): The number of output channels.
194
+ act_fn (`str`):
195
+ The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`.
196
+
197
+ Returns:
198
+ `torch.FloatTensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
199
+ `out_channels`.
200
+ """
201
+
202
+ def __init__(self, in_channels: int, out_channels: int, act_fn: str):
203
+ super().__init__()
204
+ act_fn = get_activation(act_fn)
205
+ self.conv = nn.Sequential(
206
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
207
+ act_fn,
208
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
209
+ act_fn,
210
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
211
+ )
212
+ self.skip = (
213
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
214
+ if in_channels != out_channels
215
+ else nn.Identity()
216
+ )
217
+ self.fuse = nn.ReLU()
218
+
219
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
220
+ return self.fuse(self.conv(x) + self.skip(x))
221
+
222
+
223
+ class UNetMidBlock2D(nn.Module):
224
+ """
225
+ A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
226
+
227
+ Args:
228
+ in_channels (`int`): The number of input channels.
229
+ temb_channels (`int`): The number of temporal embedding channels.
230
+ dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
231
+ num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
232
+ resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
233
+ resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
234
+ The type of normalization to apply to the time embeddings. This can help to improve the performance of the
235
+ model on tasks with long-range temporal dependencies.
236
+ resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
237
+ resnet_groups (`int`, *optional*, defaults to 32):
238
+ The number of groups to use in the group normalization layers of the resnet blocks.
239
+ attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
240
+ resnet_pre_norm (`bool`, *optional*, defaults to `True`):
241
+ Whether to use pre-normalization for the resnet blocks.
242
+ add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
243
+ attention_head_dim (`int`, *optional*, defaults to 1):
244
+ Dimension of a single attention head. The number of attention heads is determined based on this value and
245
+ the number of input channels.
246
+ output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
247
+
248
+ Returns:
249
+ `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
250
+ in_channels, height, width)`.
251
+
252
+ """
253
+
254
+ def __init__(
255
+ self,
256
+ in_channels: int,
257
+ temb_channels: int,
258
+ dropout: float = 0.0,
259
+ num_layers: int = 1,
260
+ resnet_eps: float = 1e-6,
261
+ resnet_time_scale_shift: str = "default", # default, spatial
262
+ resnet_act_fn: str = "swish",
263
+ resnet_groups: int = 32,
264
+ attn_groups: Optional[int] = None,
265
+ resnet_pre_norm: bool = True,
266
+ add_attention: bool = True,
267
+ attention_head_dim: int = 1,
268
+ output_scale_factor: float = 1.0,
269
+ ):
270
+ super().__init__()
271
+ resnet_groups = (
272
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
273
+ )
274
+ self.add_attention = add_attention
275
+
276
+ if attn_groups is None:
277
+ attn_groups = (
278
+ resnet_groups if resnet_time_scale_shift == "default" else None
279
+ )
280
+
281
+ # there is always at least one resnet
282
+ resnets = [
283
+ ResnetBlock2D(
284
+ in_channels=in_channels,
285
+ out_channels=in_channels,
286
+ temb_channels=temb_channels,
287
+ eps=resnet_eps,
288
+ groups=resnet_groups,
289
+ dropout=dropout,
290
+ time_embedding_norm=resnet_time_scale_shift,
291
+ non_linearity=resnet_act_fn,
292
+ output_scale_factor=output_scale_factor,
293
+ pre_norm=resnet_pre_norm,
294
+ )
295
+ ]
296
+ attentions = []
297
+
298
+ if attention_head_dim is None:
299
+ logger.warn(
300
+ f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
301
+ )
302
+ attention_head_dim = in_channels
303
+
304
+ for _ in range(num_layers):
305
+ if self.add_attention:
306
+ attentions.append(
307
+ Attention(
308
+ in_channels,
309
+ heads=in_channels // attention_head_dim,
310
+ dim_head=attention_head_dim,
311
+ rescale_output_factor=output_scale_factor,
312
+ eps=resnet_eps,
313
+ norm_num_groups=attn_groups,
314
+ spatial_norm_dim=temb_channels
315
+ if resnet_time_scale_shift == "spatial"
316
+ else None,
317
+ residual_connection=True,
318
+ bias=True,
319
+ upcast_softmax=True,
320
+ _from_deprecated_attn_block=True,
321
+ )
322
+ )
323
+ else:
324
+ attentions.append(None)
325
+
326
+ resnets.append(
327
+ ResnetBlock2D(
328
+ in_channels=in_channels,
329
+ out_channels=in_channels,
330
+ temb_channels=temb_channels,
331
+ eps=resnet_eps,
332
+ groups=resnet_groups,
333
+ dropout=dropout,
334
+ time_embedding_norm=resnet_time_scale_shift,
335
+ non_linearity=resnet_act_fn,
336
+ output_scale_factor=output_scale_factor,
337
+ pre_norm=resnet_pre_norm,
338
+ )
339
+ )
340
+
341
+ self.attentions = nn.ModuleList(attentions)
342
+ self.resnets = nn.ModuleList(resnets)
343
+
344
+ def forward(
345
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None
346
+ ) -> torch.FloatTensor:
347
+ hidden_states = self.resnets[0](hidden_states, temb)
348
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
349
+ if attn is not None:
350
+ hidden_states = attn(hidden_states, temb=temb)
351
+ hidden_states = resnet(hidden_states, temb)
352
+
353
+ return hidden_states
354
+
355
+
356
+ class UNetMidBlock2DCrossAttn(nn.Module):
357
+ def __init__(
358
+ self,
359
+ in_channels: int,
360
+ temb_channels: int,
361
+ dropout: float = 0.0,
362
+ num_layers: int = 1,
363
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
364
+ resnet_eps: float = 1e-6,
365
+ resnet_time_scale_shift: str = "default",
366
+ resnet_act_fn: str = "swish",
367
+ resnet_groups: int = 32,
368
+ resnet_pre_norm: bool = True,
369
+ num_attention_heads: int = 1,
370
+ output_scale_factor: float = 1.0,
371
+ cross_attention_dim: int = 1280,
372
+ dual_cross_attention: bool = False,
373
+ use_linear_projection: bool = False,
374
+ upcast_attention: bool = False,
375
+ attention_type: str = "default",
376
+ ):
377
+ super().__init__()
378
+
379
+ self.has_cross_attention = True
380
+ self.num_attention_heads = num_attention_heads
381
+ resnet_groups = (
382
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
383
+ )
384
+
385
+ # support for variable transformer layers per block
386
+ if isinstance(transformer_layers_per_block, int):
387
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
388
+
389
+ # there is always at least one resnet
390
+ resnets = [
391
+ ResnetBlock2D(
392
+ in_channels=in_channels,
393
+ out_channels=in_channels,
394
+ temb_channels=temb_channels,
395
+ eps=resnet_eps,
396
+ groups=resnet_groups,
397
+ dropout=dropout,
398
+ time_embedding_norm=resnet_time_scale_shift,
399
+ non_linearity=resnet_act_fn,
400
+ output_scale_factor=output_scale_factor,
401
+ pre_norm=resnet_pre_norm,
402
+ )
403
+ ]
404
+ attentions = []
405
+
406
+ for i in range(num_layers):
407
+ if not dual_cross_attention:
408
+ attentions.append(
409
+ Transformer2DModel(
410
+ num_attention_heads,
411
+ in_channels // num_attention_heads,
412
+ in_channels=in_channels,
413
+ num_layers=transformer_layers_per_block[i],
414
+ cross_attention_dim=cross_attention_dim,
415
+ norm_num_groups=resnet_groups,
416
+ use_linear_projection=use_linear_projection,
417
+ upcast_attention=upcast_attention,
418
+ attention_type=attention_type,
419
+ )
420
+ )
421
+ else:
422
+ attentions.append(
423
+ DualTransformer2DModel(
424
+ num_attention_heads,
425
+ in_channels // num_attention_heads,
426
+ in_channels=in_channels,
427
+ num_layers=1,
428
+ cross_attention_dim=cross_attention_dim,
429
+ norm_num_groups=resnet_groups,
430
+ )
431
+ )
432
+ resnets.append(
433
+ ResnetBlock2D(
434
+ in_channels=in_channels,
435
+ out_channels=in_channels,
436
+ temb_channels=temb_channels,
437
+ eps=resnet_eps,
438
+ groups=resnet_groups,
439
+ dropout=dropout,
440
+ time_embedding_norm=resnet_time_scale_shift,
441
+ non_linearity=resnet_act_fn,
442
+ output_scale_factor=output_scale_factor,
443
+ pre_norm=resnet_pre_norm,
444
+ )
445
+ )
446
+
447
+ self.attentions = nn.ModuleList(attentions)
448
+ self.resnets = nn.ModuleList(resnets)
449
+
450
+ self.gradient_checkpointing = False
451
+
452
+ def forward(
453
+ self,
454
+ hidden_states: torch.FloatTensor,
455
+ temb: Optional[torch.FloatTensor] = None,
456
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
457
+ attention_mask: Optional[torch.FloatTensor] = None,
458
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
459
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
460
+ ) -> torch.FloatTensor:
461
+ lora_scale = (
462
+ cross_attention_kwargs.get("scale", 1.0)
463
+ if cross_attention_kwargs is not None
464
+ else 1.0
465
+ )
466
+ hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
467
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
468
+ if self.training and self.gradient_checkpointing:
469
+
470
+ def create_custom_forward(module, return_dict=None):
471
+ def custom_forward(*inputs):
472
+ if return_dict is not None:
473
+ return module(*inputs, return_dict=return_dict)
474
+ else:
475
+ return module(*inputs)
476
+
477
+ return custom_forward
478
+
479
+ ckpt_kwargs: Dict[str, Any] = (
480
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
481
+ )
482
+ hidden_states, ref_feature = attn(
483
+ hidden_states,
484
+ encoder_hidden_states=encoder_hidden_states,
485
+ cross_attention_kwargs=cross_attention_kwargs,
486
+ attention_mask=attention_mask,
487
+ encoder_attention_mask=encoder_attention_mask,
488
+ return_dict=False,
489
+ )
490
+ hidden_states = torch.utils.checkpoint.checkpoint(
491
+ create_custom_forward(resnet),
492
+ hidden_states,
493
+ temb,
494
+ **ckpt_kwargs,
495
+ )
496
+ else:
497
+ hidden_states, ref_feature = attn(
498
+ hidden_states,
499
+ encoder_hidden_states=encoder_hidden_states,
500
+ cross_attention_kwargs=cross_attention_kwargs,
501
+ attention_mask=attention_mask,
502
+ encoder_attention_mask=encoder_attention_mask,
503
+ return_dict=False,
504
+ )
505
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
506
+
507
+ return hidden_states
508
+
509
+
510
+ class CrossAttnDownBlock2D(nn.Module):
511
+ def __init__(
512
+ self,
513
+ in_channels: int,
514
+ out_channels: int,
515
+ temb_channels: int,
516
+ dropout: float = 0.0,
517
+ num_layers: int = 1,
518
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
519
+ resnet_eps: float = 1e-6,
520
+ resnet_time_scale_shift: str = "default",
521
+ resnet_act_fn: str = "swish",
522
+ resnet_groups: int = 32,
523
+ resnet_pre_norm: bool = True,
524
+ num_attention_heads: int = 1,
525
+ cross_attention_dim: int = 1280,
526
+ output_scale_factor: float = 1.0,
527
+ downsample_padding: int = 1,
528
+ add_downsample: bool = True,
529
+ dual_cross_attention: bool = False,
530
+ use_linear_projection: bool = False,
531
+ only_cross_attention: bool = False,
532
+ upcast_attention: bool = False,
533
+ attention_type: str = "default",
534
+ ):
535
+ super().__init__()
536
+ resnets = []
537
+ attentions = []
538
+
539
+ self.has_cross_attention = True
540
+ self.num_attention_heads = num_attention_heads
541
+ if isinstance(transformer_layers_per_block, int):
542
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
543
+
544
+ for i in range(num_layers):
545
+ in_channels = in_channels if i == 0 else out_channels
546
+ resnets.append(
547
+ ResnetBlock2D(
548
+ in_channels=in_channels,
549
+ out_channels=out_channels,
550
+ temb_channels=temb_channels,
551
+ eps=resnet_eps,
552
+ groups=resnet_groups,
553
+ dropout=dropout,
554
+ time_embedding_norm=resnet_time_scale_shift,
555
+ non_linearity=resnet_act_fn,
556
+ output_scale_factor=output_scale_factor,
557
+ pre_norm=resnet_pre_norm,
558
+ )
559
+ )
560
+ if not dual_cross_attention:
561
+ attentions.append(
562
+ Transformer2DModel(
563
+ num_attention_heads,
564
+ out_channels // num_attention_heads,
565
+ in_channels=out_channels,
566
+ num_layers=transformer_layers_per_block[i],
567
+ cross_attention_dim=cross_attention_dim,
568
+ norm_num_groups=resnet_groups,
569
+ use_linear_projection=use_linear_projection,
570
+ only_cross_attention=only_cross_attention,
571
+ upcast_attention=upcast_attention,
572
+ attention_type=attention_type,
573
+ )
574
+ )
575
+ else:
576
+ attentions.append(
577
+ DualTransformer2DModel(
578
+ num_attention_heads,
579
+ out_channels // num_attention_heads,
580
+ in_channels=out_channels,
581
+ num_layers=1,
582
+ cross_attention_dim=cross_attention_dim,
583
+ norm_num_groups=resnet_groups,
584
+ )
585
+ )
586
+ self.attentions = nn.ModuleList(attentions)
587
+ self.resnets = nn.ModuleList(resnets)
588
+
589
+ if add_downsample:
590
+ self.downsamplers = nn.ModuleList(
591
+ [
592
+ Downsample2D(
593
+ out_channels,
594
+ use_conv=True,
595
+ out_channels=out_channels,
596
+ padding=downsample_padding,
597
+ name="op",
598
+ )
599
+ ]
600
+ )
601
+ else:
602
+ self.downsamplers = None
603
+
604
+ self.gradient_checkpointing = False
605
+
606
+ def forward(
607
+ self,
608
+ hidden_states: torch.FloatTensor,
609
+ temb: Optional[torch.FloatTensor] = None,
610
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
611
+ attention_mask: Optional[torch.FloatTensor] = None,
612
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
613
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
614
+ additional_residuals: Optional[torch.FloatTensor] = None,
615
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
616
+ output_states = ()
617
+
618
+ lora_scale = (
619
+ cross_attention_kwargs.get("scale", 1.0)
620
+ if cross_attention_kwargs is not None
621
+ else 1.0
622
+ )
623
+
624
+ blocks = list(zip(self.resnets, self.attentions))
625
+
626
+ for i, (resnet, attn) in enumerate(blocks):
627
+ if self.training and self.gradient_checkpointing:
628
+
629
+ def create_custom_forward(module, return_dict=None):
630
+ def custom_forward(*inputs):
631
+ if return_dict is not None:
632
+ return module(*inputs, return_dict=return_dict)
633
+ else:
634
+ return module(*inputs)
635
+
636
+ return custom_forward
637
+
638
+ ckpt_kwargs: Dict[str, Any] = (
639
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
640
+ )
641
+ hidden_states = torch.utils.checkpoint.checkpoint(
642
+ create_custom_forward(resnet),
643
+ hidden_states,
644
+ temb,
645
+ **ckpt_kwargs,
646
+ )
647
+ hidden_states, ref_feature = attn(
648
+ hidden_states,
649
+ encoder_hidden_states=encoder_hidden_states,
650
+ cross_attention_kwargs=cross_attention_kwargs,
651
+ attention_mask=attention_mask,
652
+ encoder_attention_mask=encoder_attention_mask,
653
+ return_dict=False,
654
+ )
655
+ else:
656
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
657
+ hidden_states, ref_feature = attn(
658
+ hidden_states,
659
+ encoder_hidden_states=encoder_hidden_states,
660
+ cross_attention_kwargs=cross_attention_kwargs,
661
+ attention_mask=attention_mask,
662
+ encoder_attention_mask=encoder_attention_mask,
663
+ return_dict=False,
664
+ )
665
+
666
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
667
+ if i == len(blocks) - 1 and additional_residuals is not None:
668
+ hidden_states = hidden_states + additional_residuals
669
+
670
+ output_states = output_states + (hidden_states,)
671
+
672
+ if self.downsamplers is not None:
673
+ for downsampler in self.downsamplers:
674
+ hidden_states = downsampler(hidden_states, scale=lora_scale)
675
+
676
+ output_states = output_states + (hidden_states,)
677
+
678
+ return hidden_states, output_states
679
+
680
+
681
+ class DownBlock2D(nn.Module):
682
+ def __init__(
683
+ self,
684
+ in_channels: int,
685
+ out_channels: int,
686
+ temb_channels: int,
687
+ dropout: float = 0.0,
688
+ num_layers: int = 1,
689
+ resnet_eps: float = 1e-6,
690
+ resnet_time_scale_shift: str = "default",
691
+ resnet_act_fn: str = "swish",
692
+ resnet_groups: int = 32,
693
+ resnet_pre_norm: bool = True,
694
+ output_scale_factor: float = 1.0,
695
+ add_downsample: bool = True,
696
+ downsample_padding: int = 1,
697
+ ):
698
+ super().__init__()
699
+ resnets = []
700
+
701
+ for i in range(num_layers):
702
+ in_channels = in_channels if i == 0 else out_channels
703
+ resnets.append(
704
+ ResnetBlock2D(
705
+ in_channels=in_channels,
706
+ out_channels=out_channels,
707
+ temb_channels=temb_channels,
708
+ eps=resnet_eps,
709
+ groups=resnet_groups,
710
+ dropout=dropout,
711
+ time_embedding_norm=resnet_time_scale_shift,
712
+ non_linearity=resnet_act_fn,
713
+ output_scale_factor=output_scale_factor,
714
+ pre_norm=resnet_pre_norm,
715
+ )
716
+ )
717
+
718
+ self.resnets = nn.ModuleList(resnets)
719
+
720
+ if add_downsample:
721
+ self.downsamplers = nn.ModuleList(
722
+ [
723
+ Downsample2D(
724
+ out_channels,
725
+ use_conv=True,
726
+ out_channels=out_channels,
727
+ padding=downsample_padding,
728
+ name="op",
729
+ )
730
+ ]
731
+ )
732
+ else:
733
+ self.downsamplers = None
734
+
735
+ self.gradient_checkpointing = False
736
+
737
+ def forward(
738
+ self,
739
+ hidden_states: torch.FloatTensor,
740
+ temb: Optional[torch.FloatTensor] = None,
741
+ scale: float = 1.0,
742
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
743
+ output_states = ()
744
+
745
+ for resnet in self.resnets:
746
+ if self.training and self.gradient_checkpointing:
747
+
748
+ def create_custom_forward(module):
749
+ def custom_forward(*inputs):
750
+ return module(*inputs)
751
+
752
+ return custom_forward
753
+
754
+ if is_torch_version(">=", "1.11.0"):
755
+ hidden_states = torch.utils.checkpoint.checkpoint(
756
+ create_custom_forward(resnet),
757
+ hidden_states,
758
+ temb,
759
+ use_reentrant=False,
760
+ )
761
+ else:
762
+ hidden_states = torch.utils.checkpoint.checkpoint(
763
+ create_custom_forward(resnet), hidden_states, temb
764
+ )
765
+ else:
766
+ hidden_states = resnet(hidden_states, temb, scale=scale)
767
+
768
+ output_states = output_states + (hidden_states,)
769
+
770
+ if self.downsamplers is not None:
771
+ for downsampler in self.downsamplers:
772
+ hidden_states = downsampler(hidden_states, scale=scale)
773
+
774
+ output_states = output_states + (hidden_states,)
775
+
776
+ return hidden_states, output_states
777
+
778
+
779
+ class CrossAttnUpBlock2D(nn.Module):
780
+ def __init__(
781
+ self,
782
+ in_channels: int,
783
+ out_channels: int,
784
+ prev_output_channel: int,
785
+ temb_channels: int,
786
+ resolution_idx: Optional[int] = None,
787
+ dropout: float = 0.0,
788
+ num_layers: int = 1,
789
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
790
+ resnet_eps: float = 1e-6,
791
+ resnet_time_scale_shift: str = "default",
792
+ resnet_act_fn: str = "swish",
793
+ resnet_groups: int = 32,
794
+ resnet_pre_norm: bool = True,
795
+ num_attention_heads: int = 1,
796
+ cross_attention_dim: int = 1280,
797
+ output_scale_factor: float = 1.0,
798
+ add_upsample: bool = True,
799
+ dual_cross_attention: bool = False,
800
+ use_linear_projection: bool = False,
801
+ only_cross_attention: bool = False,
802
+ upcast_attention: bool = False,
803
+ attention_type: str = "default",
804
+ ):
805
+ super().__init__()
806
+ resnets = []
807
+ attentions = []
808
+
809
+ self.has_cross_attention = True
810
+ self.num_attention_heads = num_attention_heads
811
+
812
+ if isinstance(transformer_layers_per_block, int):
813
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
814
+
815
+ for i in range(num_layers):
816
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
817
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
818
+
819
+ resnets.append(
820
+ ResnetBlock2D(
821
+ in_channels=resnet_in_channels + res_skip_channels,
822
+ out_channels=out_channels,
823
+ temb_channels=temb_channels,
824
+ eps=resnet_eps,
825
+ groups=resnet_groups,
826
+ dropout=dropout,
827
+ time_embedding_norm=resnet_time_scale_shift,
828
+ non_linearity=resnet_act_fn,
829
+ output_scale_factor=output_scale_factor,
830
+ pre_norm=resnet_pre_norm,
831
+ )
832
+ )
833
+ if not dual_cross_attention:
834
+ attentions.append(
835
+ Transformer2DModel(
836
+ num_attention_heads,
837
+ out_channels // num_attention_heads,
838
+ in_channels=out_channels,
839
+ num_layers=transformer_layers_per_block[i],
840
+ cross_attention_dim=cross_attention_dim,
841
+ norm_num_groups=resnet_groups,
842
+ use_linear_projection=use_linear_projection,
843
+ only_cross_attention=only_cross_attention,
844
+ upcast_attention=upcast_attention,
845
+ attention_type=attention_type,
846
+ )
847
+ )
848
+ else:
849
+ attentions.append(
850
+ DualTransformer2DModel(
851
+ num_attention_heads,
852
+ out_channels // num_attention_heads,
853
+ in_channels=out_channels,
854
+ num_layers=1,
855
+ cross_attention_dim=cross_attention_dim,
856
+ norm_num_groups=resnet_groups,
857
+ )
858
+ )
859
+ self.attentions = nn.ModuleList(attentions)
860
+ self.resnets = nn.ModuleList(resnets)
861
+
862
+ if add_upsample:
863
+ self.upsamplers = nn.ModuleList(
864
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
865
+ )
866
+ else:
867
+ self.upsamplers = None
868
+
869
+ self.gradient_checkpointing = False
870
+ self.resolution_idx = resolution_idx
871
+
872
+ def forward(
873
+ self,
874
+ hidden_states: torch.FloatTensor,
875
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
876
+ temb: Optional[torch.FloatTensor] = None,
877
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
878
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
879
+ upsample_size: Optional[int] = None,
880
+ attention_mask: Optional[torch.FloatTensor] = None,
881
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
882
+ ) -> torch.FloatTensor:
883
+ lora_scale = (
884
+ cross_attention_kwargs.get("scale", 1.0)
885
+ if cross_attention_kwargs is not None
886
+ else 1.0
887
+ )
888
+ is_freeu_enabled = (
889
+ getattr(self, "s1", None)
890
+ and getattr(self, "s2", None)
891
+ and getattr(self, "b1", None)
892
+ and getattr(self, "b2", None)
893
+ )
894
+
895
+ for resnet, attn in zip(self.resnets, self.attentions):
896
+ # pop res hidden states
897
+ res_hidden_states = res_hidden_states_tuple[-1]
898
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
899
+
900
+ # FreeU: Only operate on the first two stages
901
+ if is_freeu_enabled:
902
+ hidden_states, res_hidden_states = apply_freeu(
903
+ self.resolution_idx,
904
+ hidden_states,
905
+ res_hidden_states,
906
+ s1=self.s1,
907
+ s2=self.s2,
908
+ b1=self.b1,
909
+ b2=self.b2,
910
+ )
911
+
912
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
913
+
914
+ if self.training and self.gradient_checkpointing:
915
+
916
+ def create_custom_forward(module, return_dict=None):
917
+ def custom_forward(*inputs):
918
+ if return_dict is not None:
919
+ return module(*inputs, return_dict=return_dict)
920
+ else:
921
+ return module(*inputs)
922
+
923
+ return custom_forward
924
+
925
+ ckpt_kwargs: Dict[str, Any] = (
926
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
927
+ )
928
+ hidden_states = torch.utils.checkpoint.checkpoint(
929
+ create_custom_forward(resnet),
930
+ hidden_states,
931
+ temb,
932
+ **ckpt_kwargs,
933
+ )
934
+ hidden_states, ref_feature = attn(
935
+ hidden_states,
936
+ encoder_hidden_states=encoder_hidden_states,
937
+ cross_attention_kwargs=cross_attention_kwargs,
938
+ attention_mask=attention_mask,
939
+ encoder_attention_mask=encoder_attention_mask,
940
+ return_dict=False,
941
+ )
942
+ else:
943
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
944
+ hidden_states, ref_feature = attn(
945
+ hidden_states,
946
+ encoder_hidden_states=encoder_hidden_states,
947
+ cross_attention_kwargs=cross_attention_kwargs,
948
+ attention_mask=attention_mask,
949
+ encoder_attention_mask=encoder_attention_mask,
950
+ return_dict=False,
951
+ )
952
+
953
+ if self.upsamplers is not None:
954
+ for upsampler in self.upsamplers:
955
+ hidden_states = upsampler(
956
+ hidden_states, upsample_size, scale=lora_scale
957
+ )
958
+
959
+ return hidden_states
960
+
961
+
962
+ class UpBlock2D(nn.Module):
963
+ def __init__(
964
+ self,
965
+ in_channels: int,
966
+ prev_output_channel: int,
967
+ out_channels: int,
968
+ temb_channels: int,
969
+ resolution_idx: Optional[int] = None,
970
+ dropout: float = 0.0,
971
+ num_layers: int = 1,
972
+ resnet_eps: float = 1e-6,
973
+ resnet_time_scale_shift: str = "default",
974
+ resnet_act_fn: str = "swish",
975
+ resnet_groups: int = 32,
976
+ resnet_pre_norm: bool = True,
977
+ output_scale_factor: float = 1.0,
978
+ add_upsample: bool = True,
979
+ ):
980
+ super().__init__()
981
+ resnets = []
982
+
983
+ for i in range(num_layers):
984
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
985
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
986
+
987
+ resnets.append(
988
+ ResnetBlock2D(
989
+ in_channels=resnet_in_channels + res_skip_channels,
990
+ out_channels=out_channels,
991
+ temb_channels=temb_channels,
992
+ eps=resnet_eps,
993
+ groups=resnet_groups,
994
+ dropout=dropout,
995
+ time_embedding_norm=resnet_time_scale_shift,
996
+ non_linearity=resnet_act_fn,
997
+ output_scale_factor=output_scale_factor,
998
+ pre_norm=resnet_pre_norm,
999
+ )
1000
+ )
1001
+
1002
+ self.resnets = nn.ModuleList(resnets)
1003
+
1004
+ if add_upsample:
1005
+ self.upsamplers = nn.ModuleList(
1006
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
1007
+ )
1008
+ else:
1009
+ self.upsamplers = None
1010
+
1011
+ self.gradient_checkpointing = False
1012
+ self.resolution_idx = resolution_idx
1013
+
1014
+ def forward(
1015
+ self,
1016
+ hidden_states: torch.FloatTensor,
1017
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
1018
+ temb: Optional[torch.FloatTensor] = None,
1019
+ upsample_size: Optional[int] = None,
1020
+ scale: float = 1.0,
1021
+ ) -> torch.FloatTensor:
1022
+ is_freeu_enabled = (
1023
+ getattr(self, "s1", None)
1024
+ and getattr(self, "s2", None)
1025
+ and getattr(self, "b1", None)
1026
+ and getattr(self, "b2", None)
1027
+ )
1028
+
1029
+ for resnet in self.resnets:
1030
+ # pop res hidden states
1031
+ res_hidden_states = res_hidden_states_tuple[-1]
1032
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1033
+
1034
+ # FreeU: Only operate on the first two stages
1035
+ if is_freeu_enabled:
1036
+ hidden_states, res_hidden_states = apply_freeu(
1037
+ self.resolution_idx,
1038
+ hidden_states,
1039
+ res_hidden_states,
1040
+ s1=self.s1,
1041
+ s2=self.s2,
1042
+ b1=self.b1,
1043
+ b2=self.b2,
1044
+ )
1045
+
1046
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1047
+
1048
+ if self.training and self.gradient_checkpointing:
1049
+
1050
+ def create_custom_forward(module):
1051
+ def custom_forward(*inputs):
1052
+ return module(*inputs)
1053
+
1054
+ return custom_forward
1055
+
1056
+ if is_torch_version(">=", "1.11.0"):
1057
+ hidden_states = torch.utils.checkpoint.checkpoint(
1058
+ create_custom_forward(resnet),
1059
+ hidden_states,
1060
+ temb,
1061
+ use_reentrant=False,
1062
+ )
1063
+ else:
1064
+ hidden_states = torch.utils.checkpoint.checkpoint(
1065
+ create_custom_forward(resnet), hidden_states, temb
1066
+ )
1067
+ else:
1068
+ hidden_states = resnet(hidden_states, temb, scale=scale)
1069
+
1070
+ if self.upsamplers is not None:
1071
+ for upsampler in self.upsamplers:
1072
+ hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
1073
+
1074
+ return hidden_states
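For orientation (not part of the commit), here is a minimal sketch of exercising one of the blocks above in isolation. It assumes this commit is installed as a `musepose` package with `torch` and `diffusers` available; the channel sizes and import path are illustrative assumptions only.

    import torch
    # Hypothetical import path for this commit's adapted blocks.
    from musepose.models.unet_2d_blocks import CrossAttnDownBlock2D

    block = CrossAttnDownBlock2D(
        in_channels=320,
        out_channels=320,
        temb_channels=1280,
        num_layers=2,
        num_attention_heads=8,
        cross_attention_dim=768,
    )
    sample = torch.randn(1, 320, 64, 64)   # (batch, channels, height, width) latent features
    temb = torch.randn(1, 1280)            # timestep embedding
    context = torch.randn(1, 77, 768)      # encoder hidden states, e.g. CLIP text tokens
    hidden_states, output_states = block(sample, temb, encoder_hidden_states=context)
    # hidden_states: (1, 320, 32, 32) after the trailing Downsample2D
    # output_states: one residual per resnet/attention pair plus the downsampled output

Note that, unlike upstream diffusers, the attention calls in these blocks unpack `hidden_states, ref_feature`: the adapted Transformer2DModel in this commit also returns reference features, apparently for the reference-attention path elsewhere in the repository.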
musepose/models/unet_2d_condition.py ADDED
@@ -0,0 +1,1307 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.utils.checkpoint
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.loaders import UNet2DConditionLoadersMixin
10
+ from diffusers.models.activations import get_activation
11
+ from diffusers.models.attention_processor import (
12
+ ADDED_KV_ATTENTION_PROCESSORS,
13
+ CROSS_ATTENTION_PROCESSORS,
14
+ AttentionProcessor,
15
+ AttnAddedKVProcessor,
16
+ AttnProcessor,
17
+ )
18
+ from diffusers.models.embeddings import (
19
+ GaussianFourierProjection,
20
+ ImageHintTimeEmbedding,
21
+ ImageProjection,
22
+ ImageTimeEmbedding,
23
+ TextImageProjection,
24
+ TextImageTimeEmbedding,
25
+ TextTimeEmbedding,
26
+ TimestepEmbedding,
27
+ Timesteps,
28
+ )
29
+ from diffusers.models.modeling_utils import ModelMixin
30
+ from diffusers.utils import (
31
+ USE_PEFT_BACKEND,
32
+ BaseOutput,
33
+ deprecate,
34
+ logging,
35
+ scale_lora_layers,
36
+ unscale_lora_layers,
37
+ )
38
+
39
+ from .unet_2d_blocks import (
40
+ UNetMidBlock2D,
41
+ UNetMidBlock2DCrossAttn,
42
+ get_down_block,
43
+ get_up_block,
44
+ )
45
+
46
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
47
+
48
+
49
+ @dataclass
50
+ class UNet2DConditionOutput(BaseOutput):
51
+ """
52
+ The output of [`UNet2DConditionModel`].
53
+
54
+ Args:
55
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
56
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
57
+ """
58
+
59
+ sample: torch.FloatTensor = None
60
+ ref_features: Tuple[torch.FloatTensor] = None
61
+
62
+
63
+ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
64
+ r"""
65
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a
66
+ sample-shaped output.
67
+
68
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
69
+ for all models (such as downloading or saving).
70
+
71
+ Parameters:
72
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
73
+ Height and width of input/output sample.
74
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
75
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
76
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
77
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
78
+ Whether to flip the sin to cos in the time embedding.
79
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
80
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
81
+ The tuple of downsample blocks to use.
82
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
83
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
84
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
85
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
86
+ The tuple of upsample blocks to use.
87
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
88
+ Whether to include self-attention in the basic transformer blocks, see
89
+ [`~models.attention.BasicTransformerBlock`].
90
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
91
+ The tuple of output channels for each block.
92
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
93
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
94
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
95
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
96
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
97
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
98
+ If `None`, normalization and activation layers are skipped in post-processing.
99
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
100
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
101
+ The dimension of the cross attention features.
102
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
103
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
104
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
105
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
106
+ reverse_transformer_layers_per_block (`Tuple[Tuple]`, *optional*, defaults to `None`):
107
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
108
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
109
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
110
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
111
+ encoder_hid_dim (`int`, *optional*, defaults to None):
112
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
113
+ dimension to `cross_attention_dim`.
114
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
115
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
116
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
117
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
118
+ num_attention_heads (`int`, *optional*):
119
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
120
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
121
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
122
+ class_embed_type (`str`, *optional*, defaults to `None`):
123
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
124
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
125
+ addition_embed_type (`str`, *optional*, defaults to `None`):
126
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
127
+ "text". "text" will use the `TextTimeEmbedding` layer.
128
+ addition_time_embed_dim (`int`, *optional*, defaults to `None`):
129
+ Dimension for the timestep embeddings.
130
+ num_class_embeds (`int`, *optional*, defaults to `None`):
131
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
132
+ class conditioning with `class_embed_type` equal to `None`.
133
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
134
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
135
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
136
+ An optional override for the dimension of the projected time embedding.
137
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
138
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
139
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
140
+ timestep_post_act (`str`, *optional*, defaults to `None`):
141
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
142
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
143
+ The dimension of `cond_proj` layer in the timestep embedding.
144
+ conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
145
+ conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
146
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
147
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
148
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
149
+ embeddings with the class embeddings.
150
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
151
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
152
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
153
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Defaults to `False`
154
+ otherwise.
155
+ """
156
+
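As a rough sketch of what the defaults documented above amount to, the following constructs the model with an SD-1.5-style layout; the import path and the 768-dim text conditioning are assumptions for illustration, not something this commit prescribes.

    from musepose.models.unet_2d_condition import UNet2DConditionModel  # hypothetical install path

    ref_unet = UNet2DConditionModel(
        sample_size=64,            # 64x64 latents, i.e. 512x512 images with an 8x VAE
        in_channels=4,
        out_channels=4,
        cross_attention_dim=768,   # CLIP ViT-L/14 text embedding width
    )
    print(f"{sum(p.numel() for p in ref_unet.parameters()) / 1e6:.1f}M parameters")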
157
+ _supports_gradient_checkpointing = True
158
+
159
+ @register_to_config
160
+ def __init__(
161
+ self,
162
+ sample_size: Optional[int] = None,
163
+ in_channels: int = 4,
164
+ out_channels: int = 4,
165
+ center_input_sample: bool = False,
166
+ flip_sin_to_cos: bool = True,
167
+ freq_shift: int = 0,
168
+ down_block_types: Tuple[str] = (
169
+ "CrossAttnDownBlock2D",
170
+ "CrossAttnDownBlock2D",
171
+ "CrossAttnDownBlock2D",
172
+ "DownBlock2D",
173
+ ),
174
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
175
+ up_block_types: Tuple[str] = (
176
+ "UpBlock2D",
177
+ "CrossAttnUpBlock2D",
178
+ "CrossAttnUpBlock2D",
179
+ "CrossAttnUpBlock2D",
180
+ ),
181
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
182
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
183
+ layers_per_block: Union[int, Tuple[int]] = 2,
184
+ downsample_padding: int = 1,
185
+ mid_block_scale_factor: float = 1,
186
+ dropout: float = 0.0,
187
+ act_fn: str = "silu",
188
+ norm_num_groups: Optional[int] = 32,
189
+ norm_eps: float = 1e-5,
190
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
191
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
192
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
193
+ encoder_hid_dim: Optional[int] = None,
194
+ encoder_hid_dim_type: Optional[str] = None,
195
+ attention_head_dim: Union[int, Tuple[int]] = 8,
196
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
197
+ dual_cross_attention: bool = False,
198
+ use_linear_projection: bool = False,
199
+ class_embed_type: Optional[str] = None,
200
+ addition_embed_type: Optional[str] = None,
201
+ addition_time_embed_dim: Optional[int] = None,
202
+ num_class_embeds: Optional[int] = None,
203
+ upcast_attention: bool = False,
204
+ resnet_time_scale_shift: str = "default",
205
+ resnet_skip_time_act: bool = False,
206
+ resnet_out_scale_factor: int = 1.0,
207
+ time_embedding_type: str = "positional",
208
+ time_embedding_dim: Optional[int] = None,
209
+ time_embedding_act_fn: Optional[str] = None,
210
+ timestep_post_act: Optional[str] = None,
211
+ time_cond_proj_dim: Optional[int] = None,
212
+ conv_in_kernel: int = 3,
213
+ conv_out_kernel: int = 3,
214
+ projection_class_embeddings_input_dim: Optional[int] = None,
215
+ attention_type: str = "default",
216
+ class_embeddings_concat: bool = False,
217
+ mid_block_only_cross_attention: Optional[bool] = None,
218
+ cross_attention_norm: Optional[str] = None,
219
+ addition_embed_type_num_heads=64,
220
+ ):
221
+ super().__init__()
222
+
223
+ self.sample_size = sample_size
224
+
225
+ if num_attention_heads is not None:
226
+ raise ValueError(
227
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
228
+ )
229
+
230
+ # If `num_attention_heads` is not defined (which is the case for most models)
231
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
232
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
233
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
234
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
235
+ # which is why we correct for the naming here.
236
+ num_attention_heads = num_attention_heads or attention_head_dim
237
+
238
+ # Check inputs
239
+ if len(down_block_types) != len(up_block_types):
240
+ raise ValueError(
241
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
242
+ )
243
+
244
+ if len(block_out_channels) != len(down_block_types):
245
+ raise ValueError(
246
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
247
+ )
248
+
249
+ if not isinstance(only_cross_attention, bool) and len(
250
+ only_cross_attention
251
+ ) != len(down_block_types):
252
+ raise ValueError(
253
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
254
+ )
255
+
256
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
257
+ down_block_types
258
+ ):
259
+ raise ValueError(
260
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
261
+ )
262
+
263
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(
264
+ down_block_types
265
+ ):
266
+ raise ValueError(
267
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
268
+ )
269
+
270
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(
271
+ down_block_types
272
+ ):
273
+ raise ValueError(
274
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
275
+ )
276
+
277
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(
278
+ down_block_types
279
+ ):
280
+ raise ValueError(
281
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
282
+ )
283
+ if (
284
+ isinstance(transformer_layers_per_block, list)
285
+ and reverse_transformer_layers_per_block is None
286
+ ):
287
+ for layer_number_per_block in transformer_layers_per_block:
288
+ if isinstance(layer_number_per_block, list):
289
+ raise ValueError(
290
+ "Must provide `reverse_transformer_layers_per_block` if using asymmetrical UNet."
291
+ )
292
+
293
+ # input
294
+ conv_in_padding = (conv_in_kernel - 1) // 2
295
+ self.conv_in = nn.Conv2d(
296
+ in_channels,
297
+ block_out_channels[0],
298
+ kernel_size=conv_in_kernel,
299
+ padding=conv_in_padding,
300
+ )
301
+
302
+ # time
303
+ if time_embedding_type == "fourier":
304
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
305
+ if time_embed_dim % 2 != 0:
306
+ raise ValueError(
307
+ f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
308
+ )
309
+ self.time_proj = GaussianFourierProjection(
310
+ time_embed_dim // 2,
311
+ set_W_to_weight=False,
312
+ log=False,
313
+ flip_sin_to_cos=flip_sin_to_cos,
314
+ )
315
+ timestep_input_dim = time_embed_dim
316
+ elif time_embedding_type == "positional":
317
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
318
+
319
+ self.time_proj = Timesteps(
320
+ block_out_channels[0], flip_sin_to_cos, freq_shift
321
+ )
322
+ timestep_input_dim = block_out_channels[0]
323
+ else:
324
+ raise ValueError(
325
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
326
+ )
327
+
328
+ self.time_embedding = TimestepEmbedding(
329
+ timestep_input_dim,
330
+ time_embed_dim,
331
+ act_fn=act_fn,
332
+ post_act_fn=timestep_post_act,
333
+ cond_proj_dim=time_cond_proj_dim,
334
+ )
335
+
336
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
337
+ encoder_hid_dim_type = "text_proj"
338
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
339
+ logger.info(
340
+ "encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined."
341
+ )
342
+
343
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
344
+ raise ValueError(
345
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
346
+ )
347
+
348
+ if encoder_hid_dim_type == "text_proj":
349
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
350
+ elif encoder_hid_dim_type == "text_image_proj":
351
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
352
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
353
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
354
+ self.encoder_hid_proj = TextImageProjection(
355
+ text_embed_dim=encoder_hid_dim,
356
+ image_embed_dim=cross_attention_dim,
357
+ cross_attention_dim=cross_attention_dim,
358
+ )
359
+ elif encoder_hid_dim_type == "image_proj":
360
+ # Kandinsky 2.2
361
+ self.encoder_hid_proj = ImageProjection(
362
+ image_embed_dim=encoder_hid_dim,
363
+ cross_attention_dim=cross_attention_dim,
364
+ )
365
+ elif encoder_hid_dim_type is not None:
366
+ raise ValueError(
367
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'."
368
+ )
369
+ else:
370
+ self.encoder_hid_proj = None
371
+
372
+ # class embedding
373
+ if class_embed_type is None and num_class_embeds is not None:
374
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
375
+ elif class_embed_type == "timestep":
376
+ self.class_embedding = TimestepEmbedding(
377
+ timestep_input_dim, time_embed_dim, act_fn=act_fn
378
+ )
379
+ elif class_embed_type == "identity":
380
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
381
+ elif class_embed_type == "projection":
382
+ if projection_class_embeddings_input_dim is None:
383
+ raise ValueError(
384
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
385
+ )
386
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
387
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
388
+ # 2. it projects from an arbitrary input dimension.
389
+ #
390
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
391
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
392
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
393
+ self.class_embedding = TimestepEmbedding(
394
+ projection_class_embeddings_input_dim, time_embed_dim
395
+ )
396
+ elif class_embed_type == "simple_projection":
397
+ if projection_class_embeddings_input_dim is None:
398
+ raise ValueError(
399
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
400
+ )
401
+ self.class_embedding = nn.Linear(
402
+ projection_class_embeddings_input_dim, time_embed_dim
403
+ )
404
+ else:
405
+ self.class_embedding = None
406
+
407
+ if addition_embed_type == "text":
408
+ if encoder_hid_dim is not None:
409
+ text_time_embedding_from_dim = encoder_hid_dim
410
+ else:
411
+ text_time_embedding_from_dim = cross_attention_dim
412
+
413
+ self.add_embedding = TextTimeEmbedding(
414
+ text_time_embedding_from_dim,
415
+ time_embed_dim,
416
+ num_heads=addition_embed_type_num_heads,
417
+ )
418
+ elif addition_embed_type == "text_image":
419
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
420
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
421
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
422
+ self.add_embedding = TextImageTimeEmbedding(
423
+ text_embed_dim=cross_attention_dim,
424
+ image_embed_dim=cross_attention_dim,
425
+ time_embed_dim=time_embed_dim,
426
+ )
427
+ elif addition_embed_type == "text_time":
428
+ self.add_time_proj = Timesteps(
429
+ addition_time_embed_dim, flip_sin_to_cos, freq_shift
430
+ )
431
+ self.add_embedding = TimestepEmbedding(
432
+ projection_class_embeddings_input_dim, time_embed_dim
433
+ )
434
+ elif addition_embed_type == "image":
435
+ # Kandinsky 2.2
436
+ self.add_embedding = ImageTimeEmbedding(
437
+ image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
438
+ )
439
+ elif addition_embed_type == "image_hint":
440
+ # Kandinsky 2.2 ControlNet
441
+ self.add_embedding = ImageHintTimeEmbedding(
442
+ image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
443
+ )
444
+ elif addition_embed_type is not None:
445
+ raise ValueError(
446
+ f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image' or 'image_hint'."
447
+ )
448
+
449
+ if time_embedding_act_fn is None:
450
+ self.time_embed_act = None
451
+ else:
452
+ self.time_embed_act = get_activation(time_embedding_act_fn)
453
+
454
+ self.down_blocks = nn.ModuleList([])
455
+ self.up_blocks = nn.ModuleList([])
456
+
457
+ if isinstance(only_cross_attention, bool):
458
+ if mid_block_only_cross_attention is None:
459
+ mid_block_only_cross_attention = only_cross_attention
460
+
461
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
462
+
463
+ if mid_block_only_cross_attention is None:
464
+ mid_block_only_cross_attention = False
465
+
466
+ if isinstance(num_attention_heads, int):
467
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
468
+
469
+ if isinstance(attention_head_dim, int):
470
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
471
+
472
+ if isinstance(cross_attention_dim, int):
473
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
474
+
475
+ if isinstance(layers_per_block, int):
476
+ layers_per_block = [layers_per_block] * len(down_block_types)
477
+
478
+ if isinstance(transformer_layers_per_block, int):
479
+ transformer_layers_per_block = [transformer_layers_per_block] * len(
480
+ down_block_types
481
+ )
482
+
483
+ if class_embeddings_concat:
484
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
485
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
486
+ # regular time embeddings
487
+ blocks_time_embed_dim = time_embed_dim * 2
488
+ else:
489
+ blocks_time_embed_dim = time_embed_dim
490
+
491
+ # down
492
+ output_channel = block_out_channels[0]
493
+ for i, down_block_type in enumerate(down_block_types):
494
+ input_channel = output_channel
495
+ output_channel = block_out_channels[i]
496
+ is_final_block = i == len(block_out_channels) - 1
497
+
498
+ down_block = get_down_block(
499
+ down_block_type,
500
+ num_layers=layers_per_block[i],
501
+ transformer_layers_per_block=transformer_layers_per_block[i],
502
+ in_channels=input_channel,
503
+ out_channels=output_channel,
504
+ temb_channels=blocks_time_embed_dim,
505
+ add_downsample=not is_final_block,
506
+ resnet_eps=norm_eps,
507
+ resnet_act_fn=act_fn,
508
+ resnet_groups=norm_num_groups,
509
+ cross_attention_dim=cross_attention_dim[i],
510
+ num_attention_heads=num_attention_heads[i],
511
+ downsample_padding=downsample_padding,
512
+ dual_cross_attention=dual_cross_attention,
513
+ use_linear_projection=use_linear_projection,
514
+ only_cross_attention=only_cross_attention[i],
515
+ upcast_attention=upcast_attention,
516
+ resnet_time_scale_shift=resnet_time_scale_shift,
517
+ attention_type=attention_type,
518
+ resnet_skip_time_act=resnet_skip_time_act,
519
+ resnet_out_scale_factor=resnet_out_scale_factor,
520
+ cross_attention_norm=cross_attention_norm,
521
+ attention_head_dim=attention_head_dim[i]
522
+ if attention_head_dim[i] is not None
523
+ else output_channel,
524
+ dropout=dropout,
525
+ )
526
+ self.down_blocks.append(down_block)
527
+
528
+ # mid
529
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
530
+ self.mid_block = UNetMidBlock2DCrossAttn(
531
+ transformer_layers_per_block=transformer_layers_per_block[-1],
532
+ in_channels=block_out_channels[-1],
533
+ temb_channels=blocks_time_embed_dim,
534
+ dropout=dropout,
535
+ resnet_eps=norm_eps,
536
+ resnet_act_fn=act_fn,
537
+ output_scale_factor=mid_block_scale_factor,
538
+ resnet_time_scale_shift=resnet_time_scale_shift,
539
+ cross_attention_dim=cross_attention_dim[-1],
540
+ num_attention_heads=num_attention_heads[-1],
541
+ resnet_groups=norm_num_groups,
542
+ dual_cross_attention=dual_cross_attention,
543
+ use_linear_projection=use_linear_projection,
544
+ upcast_attention=upcast_attention,
545
+ attention_type=attention_type,
546
+ )
547
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
548
+ raise NotImplementedError(f"Unsupported mid_block_type: {mid_block_type}")
549
+ elif mid_block_type == "UNetMidBlock2D":
550
+ self.mid_block = UNetMidBlock2D(
551
+ in_channels=block_out_channels[-1],
552
+ temb_channels=blocks_time_embed_dim,
553
+ dropout=dropout,
554
+ num_layers=0,
555
+ resnet_eps=norm_eps,
556
+ resnet_act_fn=act_fn,
557
+ output_scale_factor=mid_block_scale_factor,
558
+ resnet_groups=norm_num_groups,
559
+ resnet_time_scale_shift=resnet_time_scale_shift,
560
+ add_attention=False,
561
+ )
562
+ elif mid_block_type is None:
563
+ self.mid_block = None
564
+ else:
565
+ raise ValueError(f"Unknown mid_block_type: {mid_block_type}")
566
+
567
+ # count how many layers upsample the images
568
+ self.num_upsamplers = 0
569
+
570
+ # up
571
+ reversed_block_out_channels = list(reversed(block_out_channels))
572
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
573
+ reversed_layers_per_block = list(reversed(layers_per_block))
574
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
575
+ reversed_transformer_layers_per_block = (
576
+ list(reversed(transformer_layers_per_block))
577
+ if reverse_transformer_layers_per_block is None
578
+ else reverse_transformer_layers_per_block
579
+ )
580
+ only_cross_attention = list(reversed(only_cross_attention))
581
+
582
+ output_channel = reversed_block_out_channels[0]
583
+ for i, up_block_type in enumerate(up_block_types):
584
+ is_final_block = i == len(block_out_channels) - 1
585
+
586
+ prev_output_channel = output_channel
587
+ output_channel = reversed_block_out_channels[i]
588
+ input_channel = reversed_block_out_channels[
589
+ min(i + 1, len(block_out_channels) - 1)
590
+ ]
591
+
592
+ # add upsample block for all BUT final layer
593
+ if not is_final_block:
594
+ add_upsample = True
595
+ self.num_upsamplers += 1
596
+ else:
597
+ add_upsample = False
598
+
599
+ up_block = get_up_block(
600
+ up_block_type,
601
+ num_layers=reversed_layers_per_block[i] + 1,
602
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
603
+ in_channels=input_channel,
604
+ out_channels=output_channel,
605
+ prev_output_channel=prev_output_channel,
606
+ temb_channels=blocks_time_embed_dim,
607
+ add_upsample=add_upsample,
608
+ resnet_eps=norm_eps,
609
+ resnet_act_fn=act_fn,
610
+ resolution_idx=i,
611
+ resnet_groups=norm_num_groups,
612
+ cross_attention_dim=reversed_cross_attention_dim[i],
613
+ num_attention_heads=reversed_num_attention_heads[i],
614
+ dual_cross_attention=dual_cross_attention,
615
+ use_linear_projection=use_linear_projection,
616
+ only_cross_attention=only_cross_attention[i],
617
+ upcast_attention=upcast_attention,
618
+ resnet_time_scale_shift=resnet_time_scale_shift,
619
+ attention_type=attention_type,
620
+ resnet_skip_time_act=resnet_skip_time_act,
621
+ resnet_out_scale_factor=resnet_out_scale_factor,
622
+ cross_attention_norm=cross_attention_norm,
623
+ attention_head_dim=attention_head_dim[i]
624
+ if attention_head_dim[i] is not None
625
+ else output_channel,
626
+ dropout=dropout,
627
+ )
628
+ self.up_blocks.append(up_block)
629
+ prev_output_channel = output_channel
630
+
631
+ # out
632
+ if norm_num_groups is not None:
633
+ self.conv_norm_out = nn.GroupNorm(
634
+ num_channels=block_out_channels[0],
635
+ num_groups=norm_num_groups,
636
+ eps=norm_eps,
637
+ )
638
+
639
+ self.conv_act = get_activation(act_fn)
640
+
641
+ else:
642
+ self.conv_norm_out = None
643
+ self.conv_act = None
645
+
646
+ conv_out_padding = (conv_out_kernel - 1) // 2
647
+ # self.conv_out = nn.Conv2d(
648
+ # block_out_channels[0],
649
+ # out_channels,
650
+ # kernel_size=conv_out_kernel,
651
+ # padding=conv_out_padding,
652
+ # )
653
+
654
+ if attention_type in ["gated", "gated-text-image"]:
655
+ positive_len = 768
656
+ if isinstance(cross_attention_dim, int):
657
+ positive_len = cross_attention_dim
658
+ elif isinstance(cross_attention_dim, tuple) or isinstance(
659
+ cross_attention_dim, list
660
+ ):
661
+ positive_len = cross_attention_dim[0]
662
+
663
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
664
+ self.position_net = PositionNet(
665
+ positive_len=positive_len,
666
+ out_dim=cross_attention_dim,
667
+ feature_type=feature_type,
668
+ )
669
+
670
+ @property
671
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
672
+ r"""
673
+ Returns:
674
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
675
+ indexed by its weight name.
676
+ """
677
+ # set recursively
678
+ processors = {}
679
+
680
+ def fn_recursive_add_processors(
681
+ name: str,
682
+ module: torch.nn.Module,
683
+ processors: Dict[str, AttentionProcessor],
684
+ ):
685
+ if hasattr(module, "get_processor"):
686
+ processors[f"{name}.processor"] = module.get_processor(
687
+ return_deprecated_lora=True
688
+ )
689
+
690
+ for sub_name, child in module.named_children():
691
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
692
+
693
+ return processors
694
+
695
+ for name, module in self.named_children():
696
+ fn_recursive_add_processors(name, module, processors)
697
+
698
+ return processors
699
+
700
+ def set_attn_processor(
701
+ self,
702
+ processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]],
703
+ _remove_lora=False,
704
+ ):
705
+ r"""
706
+ Sets the attention processor to use to compute attention.
707
+
708
+ Parameters:
709
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
710
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
711
+ for **all** `Attention` layers.
712
+
713
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
714
+ processor. This is strongly recommended when setting trainable attention processors.
715
+
716
+ """
717
+ count = len(self.attn_processors.keys())
718
+
719
+ if isinstance(processor, dict) and len(processor) != count:
720
+ raise ValueError(
721
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
722
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
723
+ )
724
+
725
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
726
+ if hasattr(module, "set_processor"):
727
+ if not isinstance(processor, dict):
728
+ module.set_processor(processor, _remove_lora=_remove_lora)
729
+ else:
730
+ module.set_processor(
731
+ processor.pop(f"{name}.processor"), _remove_lora=_remove_lora
732
+ )
733
+
734
+ for sub_name, child in module.named_children():
735
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
736
+
737
+ for name, module in self.named_children():
738
+ fn_recursive_attn_processor(name, module, processor)
739
+
740
+ def set_default_attn_processor(self):
741
+ """
742
+ Disables custom attention processors and sets the default attention implementation.
743
+ """
744
+ if all(
745
+ proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
746
+ for proc in self.attn_processors.values()
747
+ ):
748
+ processor = AttnAddedKVProcessor()
749
+ elif all(
750
+ proc.__class__ in CROSS_ATTENTION_PROCESSORS
751
+ for proc in self.attn_processors.values()
752
+ ):
753
+ processor = AttnProcessor()
754
+ else:
755
+ raise ValueError(
756
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
757
+ )
758
+
759
+ self.set_attn_processor(processor, _remove_lora=True)
760
+
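A sketch of how these processor hooks are typically used, assuming `unet` is an instance of this model and that the installed diffusers version exposes `AttnProcessor2_0` at the path below:

    from diffusers.models.attention_processor import AttnProcessor2_0

    print(len(unet.attn_processors))             # one processor per Attention layer, keyed by weight name
    unet.set_attn_processor(AttnProcessor2_0())  # apply the same processor to every layer
    unet.set_default_attn_processor()            # revert to the library default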
761
+ def set_attention_slice(self, slice_size):
762
+ r"""
763
+ Enable sliced attention computation.
764
+
765
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
766
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
767
+
768
+ Args:
769
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
770
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
771
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
772
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
773
+ must be a multiple of `slice_size`.
774
+ """
775
+ sliceable_head_dims = []
776
+
777
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
778
+ if hasattr(module, "set_attention_slice"):
779
+ sliceable_head_dims.append(module.sliceable_head_dim)
780
+
781
+ for child in module.children():
782
+ fn_recursive_retrieve_sliceable_dims(child)
783
+
784
+ # retrieve number of attention layers
785
+ for module in self.children():
786
+ fn_recursive_retrieve_sliceable_dims(module)
787
+
788
+ num_sliceable_layers = len(sliceable_head_dims)
789
+
790
+ if slice_size == "auto":
791
+ # half the attention head size is usually a good trade-off between
792
+ # speed and memory
793
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
794
+ elif slice_size == "max":
795
+ # make smallest slice possible
796
+ slice_size = num_sliceable_layers * [1]
797
+
798
+ slice_size = (
799
+ num_sliceable_layers * [slice_size]
800
+ if not isinstance(slice_size, list)
801
+ else slice_size
802
+ )
803
+
804
+ if len(slice_size) != len(sliceable_head_dims):
805
+ raise ValueError(
806
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
807
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
808
+ )
809
+
810
+ for i in range(len(slice_size)):
811
+ size = slice_size[i]
812
+ dim = sliceable_head_dims[i]
813
+ if size is not None and size > dim:
814
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
815
+
816
+ # Recursively walk through all the children.
817
+ # Any children which exposes the set_attention_slice method
818
+ # gets the message
819
+ def fn_recursive_set_attention_slice(
820
+ module: torch.nn.Module, slice_size: List[int]
821
+ ):
822
+ if hasattr(module, "set_attention_slice"):
823
+ module.set_attention_slice(slice_size.pop())
824
+
825
+ for child in module.children():
826
+ fn_recursive_set_attention_slice(child, slice_size)
827
+
828
+ reversed_slice_size = list(reversed(slice_size))
829
+ for module in self.children():
830
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
831
+
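For orientation, a minimal usage sketch of the slicing options described in the docstring above (not part of the diff; `unet` is assumed to be an instance of this model):

    unet.set_attention_slice("auto")  # halve each sliceable head dim, attention runs in two steps
    unet.set_attention_slice("max")   # one slice at a time, lowest memory use
    unet.set_attention_slice(2)       # per the docstring, attention_head_dim // 2 slices per layer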
832
+ def _set_gradient_checkpointing(self, module, value=False):
833
+ if hasattr(module, "gradient_checkpointing"):
834
+ module.gradient_checkpointing = value
835
+
836
+ def enable_freeu(self, s1, s2, b1, b2):
837
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
838
+
839
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
840
+
841
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
842
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
843
+
844
+ Args:
845
+ s1 (`float`):
846
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
847
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
848
+ s2 (`float`):
849
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
850
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
851
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
852
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
853
+ """
854
+ for i, upsample_block in enumerate(self.up_blocks):
855
+ setattr(upsample_block, "s1", s1)
856
+ setattr(upsample_block, "s2", s2)
857
+ setattr(upsample_block, "b1", b1)
858
+ setattr(upsample_block, "b2", b2)
859
+
860
+ def disable_freeu(self):
861
+ """Disables the FreeU mechanism."""
862
+ freeu_keys = {"s1", "s2", "b1", "b2"}
863
+ for i, upsample_block in enumerate(self.up_blocks):
864
+ for k in freeu_keys:
865
+ if (
866
+ hasattr(upsample_block, k)
867
+ or getattr(upsample_block, k, None) is not None
868
+ ):
869
+ setattr(upsample_block, k, None)
870
+
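A hedged sketch of toggling FreeU on an instance of this model; the scaling factors below are illustrative placeholders only (the FreeU repository linked above lists values tuned per pipeline):

    unet.enable_freeu(s1=0.9, s2=0.2, b1=1.3, b2=1.4)  # illustrative values, tune per pipeline
    # ... run denoising with FreeU active ...
    unet.disable_freeu()  # clears s1/s2/b1/b2 on every up block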
871
+ def forward(
872
+ self,
873
+ sample: torch.FloatTensor,
874
+ timestep: Union[torch.Tensor, float, int],
875
+ encoder_hidden_states: torch.Tensor,
876
+ class_labels: Optional[torch.Tensor] = None,
877
+ timestep_cond: Optional[torch.Tensor] = None,
878
+ attention_mask: Optional[torch.Tensor] = None,
879
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
880
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
881
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
882
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
883
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
884
+ encoder_attention_mask: Optional[torch.Tensor] = None,
885
+ return_dict: bool = True,
886
+ ) -> Union[UNet2DConditionOutput, Tuple]:
887
+ r"""
888
+ The [`UNet2DConditionModel`] forward method.
889
+
890
+ Args:
891
+ sample (`torch.FloatTensor`):
892
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
893
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
894
+ encoder_hidden_states (`torch.FloatTensor`):
895
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
896
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
897
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
898
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
899
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
900
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
901
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
902
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
903
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
904
+ negative values to the attention scores corresponding to "discard" tokens.
905
+ cross_attention_kwargs (`dict`, *optional*):
906
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
907
+ `self.processor` in
908
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
909
+ added_cond_kwargs: (`dict`, *optional*):
910
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
911
+ are passed along to the UNet blocks.
912
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
913
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
914
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
915
+ A tensor that if specified is added to the residual of the middle unet block.
916
+ encoder_attention_mask (`torch.Tensor`):
917
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
918
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
919
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
920
+ return_dict (`bool`, *optional*, defaults to `True`):
921
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
922
+ tuple.
923
+ cross_attention_kwargs (`dict`, *optional*):
924
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
925
+ added_cond_kwargs: (`dict`, *optional*):
926
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
927
+ are passed along to the UNet blocks.
928
+ down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
929
+ additional residuals to be added to UNet long skip connections from down blocks to up blocks for
930
+ example from ControlNet side model(s)
931
+ mid_block_additional_residual (`torch.Tensor`, *optional*):
932
+ additional residual to be added to UNet mid block output, for example from ControlNet side model
933
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
934
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
935
+
936
+ Returns:
937
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
938
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
939
+ a `tuple` is returned where the first element is the sample tensor.
940
+ """
941
+ # By default samples have to be at least a multiple of the overall upsampling factor.
942
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
943
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
944
+ # on the fly if necessary.
945
+ default_overall_up_factor = 2**self.num_upsamplers
946
+
947
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
948
+ forward_upsample_size = False
949
+ upsample_size = None
950
+
951
+ for dim in sample.shape[-2:]:
952
+ if dim % default_overall_up_factor != 0:
953
+ # Forward upsample size to force interpolation output size.
954
+ forward_upsample_size = True
955
+ break
956
+
957
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
958
+ # expects mask of shape:
959
+ # [batch, key_tokens]
960
+ # adds singleton query_tokens dimension:
961
+ # [batch, 1, key_tokens]
962
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
963
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
964
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
965
+ if attention_mask is not None:
966
+ # assume that mask is expressed as:
967
+ # (1 = keep, 0 = discard)
968
+ # convert mask into a bias that can be added to attention scores:
969
+ # (keep = +0, discard = -10000.0)
970
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
971
+ attention_mask = attention_mask.unsqueeze(1)
972
+
973
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
974
+ if encoder_attention_mask is not None:
975
+ encoder_attention_mask = (
976
+ 1 - encoder_attention_mask.to(sample.dtype)
977
+ ) * -10000.0
978
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
979
+
980
+ # 0. center input if necessary
981
+ if self.config.center_input_sample:
982
+ sample = 2 * sample - 1.0
983
+
984
+ # 1. time
985
+ timesteps = timestep
986
+ if not torch.is_tensor(timesteps):
987
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
988
+ # This would be a good case for the `match` statement (Python 3.10+)
989
+ is_mps = sample.device.type == "mps"
990
+ if isinstance(timestep, float):
991
+ dtype = torch.float32 if is_mps else torch.float64
992
+ else:
993
+ dtype = torch.int32 if is_mps else torch.int64
994
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
995
+ elif len(timesteps.shape) == 0:
996
+ timesteps = timesteps[None].to(sample.device)
997
+
998
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
999
+ timesteps = timesteps.expand(sample.shape[0])
1000
+
1001
+ t_emb = self.time_proj(timesteps)
1002
+
1003
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1004
+ # but time_embedding might actually be running in fp16. so we need to cast here.
1005
+ # there might be better ways to encapsulate this.
1006
+ t_emb = t_emb.to(dtype=sample.dtype)
1007
+
1008
+ emb = self.time_embedding(t_emb, timestep_cond)
1009
+ aug_emb = None
1010
+
1011
+ if self.class_embedding is not None:
1012
+ if class_labels is None:
1013
+ raise ValueError(
1014
+ "class_labels should be provided when num_class_embeds > 0"
1015
+ )
1016
+
1017
+ if self.config.class_embed_type == "timestep":
1018
+ class_labels = self.time_proj(class_labels)
1019
+
1020
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1021
+ # there might be better ways to encapsulate this.
1022
+ class_labels = class_labels.to(dtype=sample.dtype)
1023
+
1024
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
1025
+
1026
+ if self.config.class_embeddings_concat:
1027
+ emb = torch.cat([emb, class_emb], dim=-1)
1028
+ else:
1029
+ emb = emb + class_emb
1030
+
1031
+ if self.config.addition_embed_type == "text":
1032
+ aug_emb = self.add_embedding(encoder_hidden_states)
1033
+ elif self.config.addition_embed_type == "text_image":
1034
+ # Kandinsky 2.1 - style
1035
+ if "image_embeds" not in added_cond_kwargs:
1036
+ raise ValueError(
1037
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1038
+ )
1039
+
1040
+ image_embs = added_cond_kwargs.get("image_embeds")
1041
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
1042
+ aug_emb = self.add_embedding(text_embs, image_embs)
1043
+ elif self.config.addition_embed_type == "text_time":
1044
+ # SDXL - style
1045
+ if "text_embeds" not in added_cond_kwargs:
1046
+ raise ValueError(
1047
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
1048
+ )
1049
+ text_embeds = added_cond_kwargs.get("text_embeds")
1050
+ if "time_ids" not in added_cond_kwargs:
1051
+ raise ValueError(
1052
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
1053
+ )
1054
+ time_ids = added_cond_kwargs.get("time_ids")
1055
+ time_embeds = self.add_time_proj(time_ids.flatten())
1056
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
1057
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
1058
+ add_embeds = add_embeds.to(emb.dtype)
1059
+ aug_emb = self.add_embedding(add_embeds)
1060
+ elif self.config.addition_embed_type == "image":
1061
+ # Kandinsky 2.2 - style
1062
+ if "image_embeds" not in added_cond_kwargs:
1063
+ raise ValueError(
1064
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1065
+ )
1066
+ image_embs = added_cond_kwargs.get("image_embeds")
1067
+ aug_emb = self.add_embedding(image_embs)
1068
+ elif self.config.addition_embed_type == "image_hint":
1069
+ # Kandinsky 2.2 - style
1070
+ if (
1071
+ "image_embeds" not in added_cond_kwargs
1072
+ or "hint" not in added_cond_kwargs
1073
+ ):
1074
+ raise ValueError(
1075
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1076
+ )
1077
+ image_embs = added_cond_kwargs.get("image_embeds")
1078
+ hint = added_cond_kwargs.get("hint")
1079
+ aug_emb, hint = self.add_embedding(image_embs, hint)
1080
+ sample = torch.cat([sample, hint], dim=1)
1081
+
1082
+ emb = emb + aug_emb if aug_emb is not None else emb
1083
+
1084
+ if self.time_embed_act is not None:
1085
+ emb = self.time_embed_act(emb)
1086
+
1087
+ if (
1088
+ self.encoder_hid_proj is not None
1089
+ and self.config.encoder_hid_dim_type == "text_proj"
1090
+ ):
1091
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1092
+ elif (
1093
+ self.encoder_hid_proj is not None
1094
+ and self.config.encoder_hid_dim_type == "text_image_proj"
1095
+ ):
1096
+ # Kandinsky 2.1 - style
1097
+ if "image_embeds" not in added_cond_kwargs:
1098
+ raise ValueError(
1099
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1100
+ )
1101
+
1102
+ image_embeds = added_cond_kwargs.get("image_embeds")
1103
+ encoder_hidden_states = self.encoder_hid_proj(
1104
+ encoder_hidden_states, image_embeds
1105
+ )
1106
+ elif (
1107
+ self.encoder_hid_proj is not None
1108
+ and self.config.encoder_hid_dim_type == "image_proj"
1109
+ ):
1110
+ # Kandinsky 2.2 - style
1111
+ if "image_embeds" not in added_cond_kwargs:
1112
+ raise ValueError(
1113
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1114
+ )
1115
+ image_embeds = added_cond_kwargs.get("image_embeds")
1116
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1117
+ elif (
1118
+ self.encoder_hid_proj is not None
1119
+ and self.config.encoder_hid_dim_type == "ip_image_proj"
1120
+ ):
1121
+ if "image_embeds" not in added_cond_kwargs:
1122
+ raise ValueError(
1123
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1124
+ )
1125
+ image_embeds = added_cond_kwargs.get("image_embeds")
1126
+ image_embeds = self.encoder_hid_proj(image_embeds).to(
1127
+ encoder_hidden_states.dtype
1128
+ )
1129
+ encoder_hidden_states = torch.cat(
1130
+ [encoder_hidden_states, image_embeds], dim=1
1131
+ )
1132
+
1133
+ # 2. pre-process
1134
+ sample = self.conv_in(sample)
1135
+
1136
+ # 2.5 GLIGEN position net
1137
+ if (
1138
+ cross_attention_kwargs is not None
1139
+ and cross_attention_kwargs.get("gligen", None) is not None
1140
+ ):
1141
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1142
+ gligen_args = cross_attention_kwargs.pop("gligen")
1143
+ cross_attention_kwargs["gligen"] = {
1144
+ "objs": self.position_net(**gligen_args)
1145
+ }
1146
+
1147
+ # 3. down
1148
+ lora_scale = (
1149
+ cross_attention_kwargs.get("scale", 1.0)
1150
+ if cross_attention_kwargs is not None
1151
+ else 1.0
1152
+ )
1153
+ if USE_PEFT_BACKEND:
1154
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1155
+ scale_lora_layers(self, lora_scale)
1156
+
1157
+ is_controlnet = (
1158
+ mid_block_additional_residual is not None
1159
+ and down_block_additional_residuals is not None
1160
+ )
1161
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1162
+ is_adapter = down_intrablock_additional_residuals is not None
1163
+ # maintain backward compatibility for legacy usage, where
1164
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1165
+ # but can only use one or the other
1166
+ if (
1167
+ not is_adapter
1168
+ and mid_block_additional_residual is None
1169
+ and down_block_additional_residuals is not None
1170
+ ):
1171
+ deprecate(
1172
+ "T2I should not use down_block_additional_residuals",
1173
+ "1.3.0",
1174
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1175
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1176
+ for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
1177
+ standard_warn=False,
1178
+ )
1179
+ down_intrablock_additional_residuals = down_block_additional_residuals
1180
+ is_adapter = True
1181
+
1182
+ down_block_res_samples = (sample,)
1183
+ tot_referece_features = ()
1184
+ for downsample_block in self.down_blocks:
1185
+ if (
1186
+ hasattr(downsample_block, "has_cross_attention")
1187
+ and downsample_block.has_cross_attention
1188
+ ):
1189
+ # For t2i-adapter CrossAttnDownBlock2D
1190
+ additional_residuals = {}
1191
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1192
+ additional_residuals[
1193
+ "additional_residuals"
1194
+ ] = down_intrablock_additional_residuals.pop(0)
1195
+
1196
+ sample, res_samples = downsample_block(
1197
+ hidden_states=sample,
1198
+ temb=emb,
1199
+ encoder_hidden_states=encoder_hidden_states,
1200
+ attention_mask=attention_mask,
1201
+ cross_attention_kwargs=cross_attention_kwargs,
1202
+ encoder_attention_mask=encoder_attention_mask,
1203
+ **additional_residuals,
1204
+ )
1205
+ else:
1206
+ sample, res_samples = downsample_block(
1207
+ hidden_states=sample, temb=emb, scale=lora_scale
1208
+ )
1209
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1210
+ sample += down_intrablock_additional_residuals.pop(0)
1211
+
1212
+ down_block_res_samples += res_samples
1213
+
1214
+ if is_controlnet:
1215
+ new_down_block_res_samples = ()
1216
+
1217
+ for down_block_res_sample, down_block_additional_residual in zip(
1218
+ down_block_res_samples, down_block_additional_residuals
1219
+ ):
1220
+ down_block_res_sample = (
1221
+ down_block_res_sample + down_block_additional_residual
1222
+ )
1223
+ new_down_block_res_samples = new_down_block_res_samples + (
1224
+ down_block_res_sample,
1225
+ )
1226
+
1227
+ down_block_res_samples = new_down_block_res_samples
1228
+
1229
+ # 4. mid
1230
+ if self.mid_block is not None:
1231
+ if (
1232
+ hasattr(self.mid_block, "has_cross_attention")
1233
+ and self.mid_block.has_cross_attention
1234
+ ):
1235
+ sample = self.mid_block(
1236
+ sample,
1237
+ emb,
1238
+ encoder_hidden_states=encoder_hidden_states,
1239
+ attention_mask=attention_mask,
1240
+ cross_attention_kwargs=cross_attention_kwargs,
1241
+ encoder_attention_mask=encoder_attention_mask,
1242
+ )
1243
+ else:
1244
+ sample = self.mid_block(sample, emb)
1245
+
1246
+ # To support T2I-Adapter-XL
1247
+ if (
1248
+ is_adapter
1249
+ and len(down_intrablock_additional_residuals) > 0
1250
+ and sample.shape == down_intrablock_additional_residuals[0].shape
1251
+ ):
1252
+ sample += down_intrablock_additional_residuals.pop(0)
1253
+
1254
+ if is_controlnet:
1255
+ sample = sample + mid_block_additional_residual
1256
+
1257
+ # 5. up
1258
+ for i, upsample_block in enumerate(self.up_blocks):
1259
+ is_final_block = i == len(self.up_blocks) - 1
1260
+
1261
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1262
+ down_block_res_samples = down_block_res_samples[
1263
+ : -len(upsample_block.resnets)
1264
+ ]
1265
+
1266
+ # if we have not reached the final block and need to forward the
1267
+ # upsample size, we do it here
1268
+ if not is_final_block and forward_upsample_size:
1269
+ upsample_size = down_block_res_samples[-1].shape[2:]
1270
+
1271
+ if (
1272
+ hasattr(upsample_block, "has_cross_attention")
1273
+ and upsample_block.has_cross_attention
1274
+ ):
1275
+ sample = upsample_block(
1276
+ hidden_states=sample,
1277
+ temb=emb,
1278
+ res_hidden_states_tuple=res_samples,
1279
+ encoder_hidden_states=encoder_hidden_states,
1280
+ cross_attention_kwargs=cross_attention_kwargs,
1281
+ upsample_size=upsample_size,
1282
+ attention_mask=attention_mask,
1283
+ encoder_attention_mask=encoder_attention_mask,
1284
+ )
1285
+ else:
1286
+ sample = upsample_block(
1287
+ hidden_states=sample,
1288
+ temb=emb,
1289
+ res_hidden_states_tuple=res_samples,
1290
+ upsample_size=upsample_size,
1291
+ scale=lora_scale,
1292
+ )
1293
+
1294
+ # 6. post-process
1295
+ # if self.conv_norm_out:
1296
+ # sample = self.conv_norm_out(sample)
1297
+ # sample = self.conv_act(sample)
1298
+ # sample = self.conv_out(sample)
1299
+
1300
+ if USE_PEFT_BACKEND:
1301
+ # remove `lora_scale` from each PEFT layer
1302
+ unscale_lora_layers(self, lora_scale)
1303
+
1304
+ if not return_dict:
1305
+ return (sample,)
1306
+
1307
+ return UNet2DConditionOutput(sample=sample)
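To make the mask handling in the forward pass above concrete, here is a small standalone sketch of the (1 = keep, 0 = discard) to additive-bias conversion; shapes are illustrative and the snippet is not part of the diff:

    import torch

    mask = torch.tensor([[1, 1, 0]])      # (batch, key_tokens): 1 = keep, 0 = discard
    bias = (1 - mask.float()) * -10000.0  # keep -> 0.0, discard -> -10000.0
    bias = bias.unsqueeze(1)              # (batch, 1, key_tokens), broadcasts over query tokens
    # added to the raw attention scores, the -10000.0 entries vanish after the softmax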
musepose/models/unet_3d.py ADDED
@@ -0,0 +1,675 @@
1
+ # Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet_blocks.py
2
+
3
+ from collections import OrderedDict
4
+ from dataclasses import dataclass
5
+ from os import PathLike
6
+ from pathlib import Path
7
+ from typing import Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
13
+ from diffusers.models.attention_processor import AttentionProcessor
14
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
15
+ from diffusers.models.modeling_utils import ModelMixin
16
+ from diffusers.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging
17
+ from safetensors.torch import load_file
18
+
19
+ from .resnet import InflatedConv3d, InflatedGroupNorm
20
+ from .unet_3d_blocks import UNetMidBlock3DCrossAttn, get_down_block, get_up_block
21
+
22
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
23
+
24
+
25
+ @dataclass
26
+ class UNet3DConditionOutput(BaseOutput):
27
+ sample: torch.FloatTensor
28
+
29
+
30
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
31
+ _supports_gradient_checkpointing = True
32
+
33
+ @register_to_config
34
+ def __init__(
35
+ self,
36
+ sample_size: Optional[int] = None,
37
+ in_channels: int = 4,
38
+ out_channels: int = 4,
39
+ center_input_sample: bool = False,
40
+ flip_sin_to_cos: bool = True,
41
+ freq_shift: int = 0,
42
+ down_block_types: Tuple[str] = (
43
+ "CrossAttnDownBlock3D",
44
+ "CrossAttnDownBlock3D",
45
+ "CrossAttnDownBlock3D",
46
+ "DownBlock3D",
47
+ ),
48
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
49
+ up_block_types: Tuple[str] = (
50
+ "UpBlock3D",
51
+ "CrossAttnUpBlock3D",
52
+ "CrossAttnUpBlock3D",
53
+ "CrossAttnUpBlock3D",
54
+ ),
55
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
56
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
57
+ layers_per_block: int = 2,
58
+ downsample_padding: int = 1,
59
+ mid_block_scale_factor: float = 1,
60
+ act_fn: str = "silu",
61
+ norm_num_groups: int = 32,
62
+ norm_eps: float = 1e-5,
63
+ cross_attention_dim: int = 1280,
64
+ attention_head_dim: Union[int, Tuple[int]] = 8,
65
+ dual_cross_attention: bool = False,
66
+ use_linear_projection: bool = False,
67
+ class_embed_type: Optional[str] = None,
68
+ num_class_embeds: Optional[int] = None,
69
+ upcast_attention: bool = False,
70
+ resnet_time_scale_shift: str = "default",
71
+ use_inflated_groupnorm=False,
72
+ # Additional
73
+ use_motion_module=False,
74
+ motion_module_resolutions=(1, 2, 4, 8),
75
+ motion_module_mid_block=False,
76
+ motion_module_decoder_only=False,
77
+ motion_module_type=None,
78
+ motion_module_kwargs={},
79
+ unet_use_cross_frame_attention=None,
80
+ unet_use_temporal_attention=None,
81
+ ):
82
+ super().__init__()
83
+
84
+ self.sample_size = sample_size
85
+ time_embed_dim = block_out_channels[0] * 4
86
+
87
+ # input
88
+ self.conv_in = InflatedConv3d(
89
+ in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)
90
+ )
91
+
92
+ # time
93
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
94
+ timestep_input_dim = block_out_channels[0]
95
+
96
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
97
+
98
+ # class embedding
99
+ if class_embed_type is None and num_class_embeds is not None:
100
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
101
+ elif class_embed_type == "timestep":
102
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
103
+ elif class_embed_type == "identity":
104
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
105
+ else:
106
+ self.class_embedding = None
107
+
108
+ self.down_blocks = nn.ModuleList([])
109
+ self.mid_block = None
110
+ self.up_blocks = nn.ModuleList([])
111
+
112
+ if isinstance(only_cross_attention, bool):
113
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
114
+
115
+ if isinstance(attention_head_dim, int):
116
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
117
+
118
+ # down
119
+ output_channel = block_out_channels[0]
120
+ for i, down_block_type in enumerate(down_block_types):
121
+ res = 2**i
122
+ input_channel = output_channel
123
+ output_channel = block_out_channels[i]
124
+ is_final_block = i == len(block_out_channels) - 1
125
+
126
+ down_block = get_down_block(
127
+ down_block_type,
128
+ num_layers=layers_per_block,
129
+ in_channels=input_channel,
130
+ out_channels=output_channel,
131
+ temb_channels=time_embed_dim,
132
+ add_downsample=not is_final_block,
133
+ resnet_eps=norm_eps,
134
+ resnet_act_fn=act_fn,
135
+ resnet_groups=norm_num_groups,
136
+ cross_attention_dim=cross_attention_dim,
137
+ attn_num_head_channels=attention_head_dim[i],
138
+ downsample_padding=downsample_padding,
139
+ dual_cross_attention=dual_cross_attention,
140
+ use_linear_projection=use_linear_projection,
141
+ only_cross_attention=only_cross_attention[i],
142
+ upcast_attention=upcast_attention,
143
+ resnet_time_scale_shift=resnet_time_scale_shift,
144
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
145
+ unet_use_temporal_attention=unet_use_temporal_attention,
146
+ use_inflated_groupnorm=use_inflated_groupnorm,
147
+ use_motion_module=use_motion_module
148
+ and (res in motion_module_resolutions)
149
+ and (not motion_module_decoder_only),
150
+ motion_module_type=motion_module_type,
151
+ motion_module_kwargs=motion_module_kwargs,
152
+ )
153
+ self.down_blocks.append(down_block)
154
+
155
+ # mid
156
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
157
+ self.mid_block = UNetMidBlock3DCrossAttn(
158
+ in_channels=block_out_channels[-1],
159
+ temb_channels=time_embed_dim,
160
+ resnet_eps=norm_eps,
161
+ resnet_act_fn=act_fn,
162
+ output_scale_factor=mid_block_scale_factor,
163
+ resnet_time_scale_shift=resnet_time_scale_shift,
164
+ cross_attention_dim=cross_attention_dim,
165
+ attn_num_head_channels=attention_head_dim[-1],
166
+ resnet_groups=norm_num_groups,
167
+ dual_cross_attention=dual_cross_attention,
168
+ use_linear_projection=use_linear_projection,
169
+ upcast_attention=upcast_attention,
170
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
171
+ unet_use_temporal_attention=unet_use_temporal_attention,
172
+ use_inflated_groupnorm=use_inflated_groupnorm,
173
+ use_motion_module=use_motion_module and motion_module_mid_block,
174
+ motion_module_type=motion_module_type,
175
+ motion_module_kwargs=motion_module_kwargs,
176
+ )
177
+ else:
178
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
179
+
180
+ # count how many layers upsample the videos
181
+ self.num_upsamplers = 0
182
+
183
+ # up
184
+ reversed_block_out_channels = list(reversed(block_out_channels))
185
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
186
+ only_cross_attention = list(reversed(only_cross_attention))
187
+ output_channel = reversed_block_out_channels[0]
188
+ for i, up_block_type in enumerate(up_block_types):
189
+ res = 2 ** (3 - i)
190
+ is_final_block = i == len(block_out_channels) - 1
191
+
192
+ prev_output_channel = output_channel
193
+ output_channel = reversed_block_out_channels[i]
194
+ input_channel = reversed_block_out_channels[
195
+ min(i + 1, len(block_out_channels) - 1)
196
+ ]
197
+
198
+ # add upsample block for all BUT final layer
199
+ if not is_final_block:
200
+ add_upsample = True
201
+ self.num_upsamplers += 1
202
+ else:
203
+ add_upsample = False
204
+
205
+ up_block = get_up_block(
206
+ up_block_type,
207
+ num_layers=layers_per_block + 1,
208
+ in_channels=input_channel,
209
+ out_channels=output_channel,
210
+ prev_output_channel=prev_output_channel,
211
+ temb_channels=time_embed_dim,
212
+ add_upsample=add_upsample,
213
+ resnet_eps=norm_eps,
214
+ resnet_act_fn=act_fn,
215
+ resnet_groups=norm_num_groups,
216
+ cross_attention_dim=cross_attention_dim,
217
+ attn_num_head_channels=reversed_attention_head_dim[i],
218
+ dual_cross_attention=dual_cross_attention,
219
+ use_linear_projection=use_linear_projection,
220
+ only_cross_attention=only_cross_attention[i],
221
+ upcast_attention=upcast_attention,
222
+ resnet_time_scale_shift=resnet_time_scale_shift,
223
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
224
+ unet_use_temporal_attention=unet_use_temporal_attention,
225
+ use_inflated_groupnorm=use_inflated_groupnorm,
226
+ use_motion_module=use_motion_module
227
+ and (res in motion_module_resolutions),
228
+ motion_module_type=motion_module_type,
229
+ motion_module_kwargs=motion_module_kwargs,
230
+ )
231
+ self.up_blocks.append(up_block)
232
+ prev_output_channel = output_channel
233
+
234
+ # out
235
+ if use_inflated_groupnorm:
236
+ self.conv_norm_out = InflatedGroupNorm(
237
+ num_channels=block_out_channels[0],
238
+ num_groups=norm_num_groups,
239
+ eps=norm_eps,
240
+ )
241
+ else:
242
+ self.conv_norm_out = nn.GroupNorm(
243
+ num_channels=block_out_channels[0],
244
+ num_groups=norm_num_groups,
245
+ eps=norm_eps,
246
+ )
247
+ self.conv_act = nn.SiLU()
248
+ self.conv_out = InflatedConv3d(
249
+ block_out_channels[0], out_channels, kernel_size=3, padding=1
250
+ )
251
+
252
+ @property
253
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
254
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
255
+ r"""
256
+ Returns:
257
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
258
+ indexed by their weight names.
259
+ """
260
+ # set recursively
261
+ processors = {}
262
+
263
+ def fn_recursive_add_processors(
264
+ name: str,
265
+ module: torch.nn.Module,
266
+ processors: Dict[str, AttentionProcessor],
267
+ ):
268
+ if hasattr(module, "set_processor"):
269
+ processors[f"{name}.processor"] = module.processor
270
+
271
+ for sub_name, child in module.named_children():
272
+ if "temporal_transformer" not in sub_name:
273
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
274
+
275
+ return processors
276
+
277
+ for name, module in self.named_children():
278
+ if "temporal_transformer" not in name:
279
+ fn_recursive_add_processors(name, module, processors)
280
+
281
+ return processors
282
+
283
+ def set_attention_slice(self, slice_size):
284
+ r"""
285
+ Enable sliced attention computation.
286
+
287
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
288
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
289
+
290
+ Args:
291
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
292
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
293
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
294
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
295
+ must be a multiple of `slice_size`.
296
+ """
297
+ sliceable_head_dims = []
298
+
299
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
300
+ if hasattr(module, "set_attention_slice"):
301
+ sliceable_head_dims.append(module.sliceable_head_dim)
302
+
303
+ for child in module.children():
304
+ fn_recursive_retrieve_slicable_dims(child)
305
+
306
+ # retrieve number of attention layers
307
+ for module in self.children():
308
+ fn_recursive_retrieve_slicable_dims(module)
309
+
310
+ num_slicable_layers = len(sliceable_head_dims)
311
+
312
+ if slice_size == "auto":
313
+ # half the attention head size is usually a good trade-off between
314
+ # speed and memory
315
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
316
+ elif slice_size == "max":
317
+ # make smallest slice possible
318
+ slice_size = num_slicable_layers * [1]
319
+
320
+ slice_size = (
321
+ num_slicable_layers * [slice_size]
322
+ if not isinstance(slice_size, list)
323
+ else slice_size
324
+ )
325
+
326
+ if len(slice_size) != len(sliceable_head_dims):
327
+ raise ValueError(
328
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
329
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
330
+ )
331
+
332
+ for i in range(len(slice_size)):
333
+ size = slice_size[i]
334
+ dim = sliceable_head_dims[i]
335
+ if size is not None and size > dim:
336
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
337
+
338
+ # Recursively walk through all the children.
339
+ # Any children which exposes the set_attention_slice method
340
+ # gets the message
341
+ def fn_recursive_set_attention_slice(
342
+ module: torch.nn.Module, slice_size: List[int]
343
+ ):
344
+ if hasattr(module, "set_attention_slice"):
345
+ module.set_attention_slice(slice_size.pop())
346
+
347
+ for child in module.children():
348
+ fn_recursive_set_attention_slice(child, slice_size)
349
+
350
+ reversed_slice_size = list(reversed(slice_size))
351
+ for module in self.children():
352
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
353
+
354
+ def _set_gradient_checkpointing(self, module, value=False):
355
+ if hasattr(module, "gradient_checkpointing"):
356
+ module.gradient_checkpointing = value
357
+
358
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
359
+ def set_attn_processor(
360
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
361
+ ):
362
+ r"""
363
+ Sets the attention processor to use to compute attention.
364
+
365
+ Parameters:
366
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
367
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
368
+ for **all** `Attention` layers.
369
+
370
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
371
+ processor. This is strongly recommended when setting trainable attention processors.
372
+
373
+ """
374
+ count = len(self.attn_processors.keys())
375
+
376
+ if isinstance(processor, dict) and len(processor) != count:
377
+ raise ValueError(
378
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
379
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
380
+ )
381
+
382
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
383
+ if hasattr(module, "set_processor"):
384
+ if not isinstance(processor, dict):
385
+ module.set_processor(processor)
386
+ else:
387
+ module.set_processor(processor.pop(f"{name}.processor"))
388
+
389
+ for sub_name, child in module.named_children():
390
+ if "temporal_transformer" not in sub_name:
391
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
392
+
393
+ for name, module in self.named_children():
394
+ if "temporal_transformer" not in name:
395
+ fn_recursive_attn_processor(name, module, processor)
396
+
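When a dict is passed to `set_attn_processor`, its keys are exactly the module paths returned by `attn_processors` (each ending in `.processor`). A hedged sketch, assuming `unet` is an instance of this model and using the stock diffusers processor:

    from diffusers.models.attention_processor import AttnProcessor

    # one processor for every attention layer
    unet.set_attn_processor(AttnProcessor())

    # or per layer, reusing the keys reported by `attn_processors`
    unet.set_attn_processor({name: AttnProcessor() for name in unet.attn_processors})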
397
+ def forward(
398
+ self,
399
+ sample: torch.FloatTensor,
400
+ timestep: Union[torch.Tensor, float, int],
401
+ encoder_hidden_states: torch.Tensor,
402
+ class_labels: Optional[torch.Tensor] = None,
403
+ pose_cond_fea: Optional[torch.Tensor] = None,
404
+ attention_mask: Optional[torch.Tensor] = None,
405
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
406
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
407
+ return_dict: bool = True,
408
+ ) -> Union[UNet3DConditionOutput, Tuple]:
409
+ r"""
410
+ Args:
411
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
412
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
413
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
414
+ return_dict (`bool`, *optional*, defaults to `True`):
415
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
416
+
417
+ Returns:
418
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
419
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
420
+ returning a tuple, the first element is the sample tensor.
421
+ """
422
+ # By default samples have to be at least a multiple of the overall upsampling factor.
423
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
424
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
425
+ # on the fly if necessary.
426
+ default_overall_up_factor = 2**self.num_upsamplers
427
+
428
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
429
+ forward_upsample_size = False
430
+ upsample_size = None
431
+
432
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
433
+ logger.info("Forward upsample size to force interpolation output size.")
434
+ forward_upsample_size = True
435
+
436
+ # prepare attention_mask
437
+ if attention_mask is not None:
438
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
439
+ attention_mask = attention_mask.unsqueeze(1)
440
+
441
+ # center input if necessary
442
+ if self.config.center_input_sample:
443
+ sample = 2 * sample - 1.0
444
+
445
+ # time
446
+ timesteps = timestep
447
+ if not torch.is_tensor(timesteps):
448
+ # This would be a good case for the `match` statement (Python 3.10+)
449
+ is_mps = sample.device.type == "mps"
450
+ if isinstance(timestep, float):
451
+ dtype = torch.float32 if is_mps else torch.float64
452
+ else:
453
+ dtype = torch.int32 if is_mps else torch.int64
454
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
455
+ elif len(timesteps.shape) == 0:
456
+ timesteps = timesteps[None].to(sample.device)
457
+
458
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
459
+ timesteps = timesteps.expand(sample.shape[0])
460
+
461
+ t_emb = self.time_proj(timesteps)
462
+
463
+ # timesteps does not contain any weights and will always return f32 tensors
464
+ # but time_embedding might actually be running in fp16. so we need to cast here.
465
+ # there might be better ways to encapsulate this.
466
+ t_emb = t_emb.to(dtype=self.dtype)
467
+ emb = self.time_embedding(t_emb)
468
+
469
+ if self.class_embedding is not None:
470
+ if class_labels is None:
471
+ raise ValueError(
472
+ "class_labels should be provided when num_class_embeds > 0"
473
+ )
474
+
475
+ if self.config.class_embed_type == "timestep":
476
+ class_labels = self.time_proj(class_labels)
477
+
478
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
479
+ emb = emb + class_emb
480
+
481
+ # pre-process
482
+ sample = self.conv_in(sample)
483
+ if pose_cond_fea is not None:
484
+ sample = sample + pose_cond_fea
485
+
486
+ # down
487
+ down_block_res_samples = (sample,)
488
+ for downsample_block in self.down_blocks:
489
+ if (
490
+ hasattr(downsample_block, "has_cross_attention")
491
+ and downsample_block.has_cross_attention
492
+ ):
493
+ sample, res_samples = downsample_block(
494
+ hidden_states=sample,
495
+ temb=emb,
496
+ encoder_hidden_states=encoder_hidden_states,
497
+ attention_mask=attention_mask,
498
+ )
499
+ else:
500
+ sample, res_samples = downsample_block(
501
+ hidden_states=sample,
502
+ temb=emb,
503
+ encoder_hidden_states=encoder_hidden_states,
504
+ )
505
+
506
+ down_block_res_samples += res_samples
507
+
508
+ if down_block_additional_residuals is not None:
509
+ new_down_block_res_samples = ()
510
+
511
+ for down_block_res_sample, down_block_additional_residual in zip(
512
+ down_block_res_samples, down_block_additional_residuals
513
+ ):
514
+ down_block_res_sample = (
515
+ down_block_res_sample + down_block_additional_residual
516
+ )
517
+ new_down_block_res_samples += (down_block_res_sample,)
518
+
519
+ down_block_res_samples = new_down_block_res_samples
520
+
521
+ # mid
522
+ sample = self.mid_block(
523
+ sample,
524
+ emb,
525
+ encoder_hidden_states=encoder_hidden_states,
526
+ attention_mask=attention_mask,
527
+ )
528
+
529
+ if mid_block_additional_residual is not None:
530
+ sample = sample + mid_block_additional_residual
531
+
532
+ # up
533
+ for i, upsample_block in enumerate(self.up_blocks):
534
+ is_final_block = i == len(self.up_blocks) - 1
535
+
536
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
537
+ down_block_res_samples = down_block_res_samples[
538
+ : -len(upsample_block.resnets)
539
+ ]
540
+
541
+ # if we have not reached the final block and need to forward the
542
+ # upsample size, we do it here
543
+ if not is_final_block and forward_upsample_size:
544
+ upsample_size = down_block_res_samples[-1].shape[2:]
545
+
546
+ if (
547
+ hasattr(upsample_block, "has_cross_attention")
548
+ and upsample_block.has_cross_attention
549
+ ):
550
+ sample = upsample_block(
551
+ hidden_states=sample,
552
+ temb=emb,
553
+ res_hidden_states_tuple=res_samples,
554
+ encoder_hidden_states=encoder_hidden_states,
555
+ upsample_size=upsample_size,
556
+ attention_mask=attention_mask,
557
+ )
558
+ else:
559
+ sample = upsample_block(
560
+ hidden_states=sample,
561
+ temb=emb,
562
+ res_hidden_states_tuple=res_samples,
563
+ upsample_size=upsample_size,
564
+ encoder_hidden_states=encoder_hidden_states,
565
+ )
566
+
567
+ # post-process
568
+ sample = self.conv_norm_out(sample)
569
+ sample = self.conv_act(sample)
570
+ sample = self.conv_out(sample)
571
+
572
+ if not return_dict:
573
+ return (sample,)
574
+
575
+ return UNet3DConditionOutput(sample=sample)
576
+
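For orientation, a hedged sketch of the tensor layout this forward pass expects; every size below is illustrative and must match the instantiated config (`block_out_channels[0]` for `pose_cond_fea`, `cross_attention_dim` for the text embedding):

    import torch

    latents  = torch.randn(1, 4, 16, 32, 32)    # (batch, channels, frames, height, width)
    pose_fea = torch.randn(1, 320, 16, 32, 32)  # added right after conv_in, so channels = block_out_channels[0]
    text_emb = torch.randn(1, 77, 768)          # (batch, sequence_length, cross_attention_dim)

    out = unet(latents, timestep=10, encoder_hidden_states=text_emb, pose_cond_fea=pose_fea).sample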
577
+ @classmethod
578
+ def from_pretrained_2d(
579
+ cls,
580
+ pretrained_model_path: PathLike,
581
+ motion_module_path: PathLike,
582
+ subfolder=None,
583
+ unet_additional_kwargs=None,
584
+ mm_zero_proj_out=False,
585
+ ):
586
+ pretrained_model_path = Path(pretrained_model_path)
587
+ motion_module_path = Path(motion_module_path)
588
+ if subfolder is not None:
589
+ pretrained_model_path = pretrained_model_path.joinpath(subfolder)
590
+ logger.info(
591
+ f"loading temporal unet's pretrained weights from {pretrained_model_path} ..."
592
+ )
593
+
594
+ config_file = pretrained_model_path / "config.json"
595
+ if not (config_file.exists() and config_file.is_file()):
596
+ raise RuntimeError(f"{config_file} does not exist or is not a file")
597
+
598
+ unet_config = cls.load_config(config_file)
599
+ unet_config["_class_name"] = cls.__name__
600
+ unet_config["down_block_types"] = [
601
+ "CrossAttnDownBlock3D",
602
+ "CrossAttnDownBlock3D",
603
+ "CrossAttnDownBlock3D",
604
+ "DownBlock3D",
605
+ ]
606
+ unet_config["up_block_types"] = [
607
+ "UpBlock3D",
608
+ "CrossAttnUpBlock3D",
609
+ "CrossAttnUpBlock3D",
610
+ "CrossAttnUpBlock3D",
611
+ ]
612
+ unet_config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
613
+
614
+ model = cls.from_config(unet_config, **unet_additional_kwargs)
615
+ # load the vanilla weights
616
+ if pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME).exists():
617
+ logger.debug(
618
+ f"loading safetensors weights from {pretrained_model_path} ..."
619
+ )
620
+ state_dict = load_file(
621
+ pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME), device="cpu"
622
+ )
623
+
624
+ elif pretrained_model_path.joinpath(WEIGHTS_NAME).exists():
625
+ logger.debug(f"loading weights from {pretrained_model_path} ...")
626
+ state_dict = torch.load(
627
+ pretrained_model_path.joinpath(WEIGHTS_NAME),
628
+ map_location="cpu",
629
+ weights_only=True,
630
+ )
631
+ else:
632
+ raise FileNotFoundError(f"no weights file found in {pretrained_model_path}")
633
+
634
+ # load the motion module weights
635
+ if motion_module_path.exists() and motion_module_path.is_file():
636
+ if motion_module_path.suffix.lower() in [".pth", ".pt", ".ckpt"]:
637
+ logger.info(f"Load motion module params from {motion_module_path}")
638
+ motion_state_dict = torch.load(
639
+ motion_module_path, map_location="cpu", weights_only=True
640
+ )
641
+ elif motion_module_path.suffix.lower() == ".safetensors":
642
+ motion_state_dict = load_file(motion_module_path, device="cpu")
643
+ else:
644
+ raise RuntimeError(
645
+ f"unknown file format for motion module weights: {motion_module_path.suffix}"
646
+ )
647
+ if mm_zero_proj_out:
648
+ logger.info("Zero-initializing proj_out layers in the motion module ...")
649
+ new_motion_state_dict = OrderedDict()
650
+ for k in motion_state_dict:
651
+ if "proj_out" in k:
652
+ continue
653
+ new_motion_state_dict[k] = motion_state_dict[k]
654
+ motion_state_dict = new_motion_state_dict
655
+
656
+
657
+
658
+ for weight_name in list(motion_state_dict.keys()):
659
+ if weight_name[-2:] == 'pe':
660
+ del motion_state_dict[weight_name]
661
+ # print(weight_name)
662
+
663
+ # merge the state dicts
664
+ state_dict.update(motion_state_dict)
665
+
666
+ # load the weights into the model
667
+ m, u = model.load_state_dict(state_dict, strict=False)
668
+ logger.debug(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
669
+
670
+ params = [
671
+ p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()
672
+ ]
673
+ logger.info(f"Loaded {sum(params) / 1e6}M-parameter motion module")
674
+
675
+ return model
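A hedged usage sketch of the loader above; the paths are placeholders and the `unet_additional_kwargs` keys are illustrative overrides, not values shipped with this commit:

    unet = UNet3DConditionModel.from_pretrained_2d(
        pretrained_model_path="path/to/stable-diffusion",  # placeholder; must hold config.json plus weights
        motion_module_path="path/to/motion_module.pth",    # placeholder .pth/.pt/.ckpt or .safetensors file
        subfolder="unet",
        unet_additional_kwargs={
            "use_motion_module": True,       # illustrative keyword overrides
            "motion_module_type": "Vanilla",
            "motion_module_kwargs": {},
        },
    )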
musepose/models/unet_3d_blocks.py ADDED
@@ -0,0 +1,871 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+
3
+ import pdb
4
+
5
+ import torch
6
+ from torch import nn
7
+
8
+ from .motion_module import get_motion_module
9
+
10
+ # from .motion_module import get_motion_module
11
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
12
+ from .transformer_3d import Transformer3DModel
13
+
14
+
15
+ def get_down_block(
16
+ down_block_type,
17
+ num_layers,
18
+ in_channels,
19
+ out_channels,
20
+ temb_channels,
21
+ add_downsample,
22
+ resnet_eps,
23
+ resnet_act_fn,
24
+ attn_num_head_channels,
25
+ resnet_groups=None,
26
+ cross_attention_dim=None,
27
+ downsample_padding=None,
28
+ dual_cross_attention=False,
29
+ use_linear_projection=False,
30
+ only_cross_attention=False,
31
+ upcast_attention=False,
32
+ resnet_time_scale_shift="default",
33
+ unet_use_cross_frame_attention=None,
34
+ unet_use_temporal_attention=None,
35
+ use_inflated_groupnorm=None,
36
+ use_motion_module=None,
37
+ motion_module_type=None,
38
+ motion_module_kwargs=None,
39
+ ):
40
+ down_block_type = (
41
+ down_block_type[7:]
42
+ if down_block_type.startswith("UNetRes")
43
+ else down_block_type
44
+ )
45
+ if down_block_type == "DownBlock3D":
46
+ return DownBlock3D(
47
+ num_layers=num_layers,
48
+ in_channels=in_channels,
49
+ out_channels=out_channels,
50
+ temb_channels=temb_channels,
51
+ add_downsample=add_downsample,
52
+ resnet_eps=resnet_eps,
53
+ resnet_act_fn=resnet_act_fn,
54
+ resnet_groups=resnet_groups,
55
+ downsample_padding=downsample_padding,
56
+ resnet_time_scale_shift=resnet_time_scale_shift,
57
+ use_inflated_groupnorm=use_inflated_groupnorm,
58
+ use_motion_module=use_motion_module,
59
+ motion_module_type=motion_module_type,
60
+ motion_module_kwargs=motion_module_kwargs,
61
+ )
62
+ elif down_block_type == "CrossAttnDownBlock3D":
63
+ if cross_attention_dim is None:
64
+ raise ValueError(
65
+ "cross_attention_dim must be specified for CrossAttnDownBlock3D"
66
+ )
67
+ return CrossAttnDownBlock3D(
68
+ num_layers=num_layers,
69
+ in_channels=in_channels,
70
+ out_channels=out_channels,
71
+ temb_channels=temb_channels,
72
+ add_downsample=add_downsample,
73
+ resnet_eps=resnet_eps,
74
+ resnet_act_fn=resnet_act_fn,
75
+ resnet_groups=resnet_groups,
76
+ downsample_padding=downsample_padding,
77
+ cross_attention_dim=cross_attention_dim,
78
+ attn_num_head_channels=attn_num_head_channels,
79
+ dual_cross_attention=dual_cross_attention,
80
+ use_linear_projection=use_linear_projection,
81
+ only_cross_attention=only_cross_attention,
82
+ upcast_attention=upcast_attention,
83
+ resnet_time_scale_shift=resnet_time_scale_shift,
84
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
85
+ unet_use_temporal_attention=unet_use_temporal_attention,
86
+ use_inflated_groupnorm=use_inflated_groupnorm,
87
+ use_motion_module=use_motion_module,
88
+ motion_module_type=motion_module_type,
89
+ motion_module_kwargs=motion_module_kwargs,
90
+ )
91
+ raise ValueError(f"{down_block_type} does not exist.")
92
+
93
+
94
+ def get_up_block(
95
+ up_block_type,
96
+ num_layers,
97
+ in_channels,
98
+ out_channels,
99
+ prev_output_channel,
100
+ temb_channels,
101
+ add_upsample,
102
+ resnet_eps,
103
+ resnet_act_fn,
104
+ attn_num_head_channels,
105
+ resnet_groups=None,
106
+ cross_attention_dim=None,
107
+ dual_cross_attention=False,
108
+ use_linear_projection=False,
109
+ only_cross_attention=False,
110
+ upcast_attention=False,
111
+ resnet_time_scale_shift="default",
112
+ unet_use_cross_frame_attention=None,
113
+ unet_use_temporal_attention=None,
114
+ use_inflated_groupnorm=None,
115
+ use_motion_module=None,
116
+ motion_module_type=None,
117
+ motion_module_kwargs=None,
118
+ ):
119
+ up_block_type = (
120
+ up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
121
+ )
122
+ if up_block_type == "UpBlock3D":
123
+ return UpBlock3D(
124
+ num_layers=num_layers,
125
+ in_channels=in_channels,
126
+ out_channels=out_channels,
127
+ prev_output_channel=prev_output_channel,
128
+ temb_channels=temb_channels,
129
+ add_upsample=add_upsample,
130
+ resnet_eps=resnet_eps,
131
+ resnet_act_fn=resnet_act_fn,
132
+ resnet_groups=resnet_groups,
133
+ resnet_time_scale_shift=resnet_time_scale_shift,
134
+ use_inflated_groupnorm=use_inflated_groupnorm,
135
+ use_motion_module=use_motion_module,
136
+ motion_module_type=motion_module_type,
137
+ motion_module_kwargs=motion_module_kwargs,
138
+ )
139
+ elif up_block_type == "CrossAttnUpBlock3D":
140
+ if cross_attention_dim is None:
141
+ raise ValueError(
142
+ "cross_attention_dim must be specified for CrossAttnUpBlock3D"
143
+ )
144
+ return CrossAttnUpBlock3D(
145
+ num_layers=num_layers,
146
+ in_channels=in_channels,
147
+ out_channels=out_channels,
148
+ prev_output_channel=prev_output_channel,
149
+ temb_channels=temb_channels,
150
+ add_upsample=add_upsample,
151
+ resnet_eps=resnet_eps,
152
+ resnet_act_fn=resnet_act_fn,
153
+ resnet_groups=resnet_groups,
154
+ cross_attention_dim=cross_attention_dim,
155
+ attn_num_head_channels=attn_num_head_channels,
156
+ dual_cross_attention=dual_cross_attention,
157
+ use_linear_projection=use_linear_projection,
158
+ only_cross_attention=only_cross_attention,
159
+ upcast_attention=upcast_attention,
160
+ resnet_time_scale_shift=resnet_time_scale_shift,
161
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
162
+ unet_use_temporal_attention=unet_use_temporal_attention,
163
+ use_inflated_groupnorm=use_inflated_groupnorm,
164
+ use_motion_module=use_motion_module,
165
+ motion_module_type=motion_module_type,
166
+ motion_module_kwargs=motion_module_kwargs,
167
+ )
168
+ raise ValueError(f"{up_block_type} does not exist.")
169
+
170
+
171
+ class UNetMidBlock3DCrossAttn(nn.Module):
172
+ def __init__(
173
+ self,
174
+ in_channels: int,
175
+ temb_channels: int,
176
+ dropout: float = 0.0,
177
+ num_layers: int = 1,
178
+ resnet_eps: float = 1e-6,
179
+ resnet_time_scale_shift: str = "default",
180
+ resnet_act_fn: str = "swish",
181
+ resnet_groups: int = 32,
182
+ resnet_pre_norm: bool = True,
183
+ attn_num_head_channels=1,
184
+ output_scale_factor=1.0,
185
+ cross_attention_dim=1280,
186
+ dual_cross_attention=False,
187
+ use_linear_projection=False,
188
+ upcast_attention=False,
189
+ unet_use_cross_frame_attention=None,
190
+ unet_use_temporal_attention=None,
191
+ use_inflated_groupnorm=None,
192
+ use_motion_module=None,
193
+ motion_module_type=None,
194
+ motion_module_kwargs=None,
195
+ ):
196
+ super().__init__()
197
+
198
+ self.has_cross_attention = True
199
+ self.attn_num_head_channels = attn_num_head_channels
200
+ resnet_groups = (
201
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
202
+ )
203
+
204
+ # there is always at least one resnet
205
+ resnets = [
206
+ ResnetBlock3D(
207
+ in_channels=in_channels,
208
+ out_channels=in_channels,
209
+ temb_channels=temb_channels,
210
+ eps=resnet_eps,
211
+ groups=resnet_groups,
212
+ dropout=dropout,
213
+ time_embedding_norm=resnet_time_scale_shift,
214
+ non_linearity=resnet_act_fn,
215
+ output_scale_factor=output_scale_factor,
216
+ pre_norm=resnet_pre_norm,
217
+ use_inflated_groupnorm=use_inflated_groupnorm,
218
+ )
219
+ ]
220
+ attentions = []
221
+ motion_modules = []
222
+
223
+ for _ in range(num_layers):
224
+ if dual_cross_attention:
225
+ raise NotImplementedError
226
+ attentions.append(
227
+ Transformer3DModel(
228
+ attn_num_head_channels,
229
+ in_channels // attn_num_head_channels,
230
+ in_channels=in_channels,
231
+ num_layers=1,
232
+ cross_attention_dim=cross_attention_dim,
233
+ norm_num_groups=resnet_groups,
234
+ use_linear_projection=use_linear_projection,
235
+ upcast_attention=upcast_attention,
236
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
237
+ unet_use_temporal_attention=unet_use_temporal_attention,
238
+ )
239
+ )
240
+ motion_modules.append(
241
+ get_motion_module(
242
+ in_channels=in_channels,
243
+ motion_module_type=motion_module_type,
244
+ motion_module_kwargs=motion_module_kwargs,
245
+ )
246
+ if use_motion_module
247
+ else None
248
+ )
249
+ resnets.append(
250
+ ResnetBlock3D(
251
+ in_channels=in_channels,
252
+ out_channels=in_channels,
253
+ temb_channels=temb_channels,
254
+ eps=resnet_eps,
255
+ groups=resnet_groups,
256
+ dropout=dropout,
257
+ time_embedding_norm=resnet_time_scale_shift,
258
+ non_linearity=resnet_act_fn,
259
+ output_scale_factor=output_scale_factor,
260
+ pre_norm=resnet_pre_norm,
261
+ use_inflated_groupnorm=use_inflated_groupnorm,
262
+ )
263
+ )
264
+
265
+ self.attentions = nn.ModuleList(attentions)
266
+ self.resnets = nn.ModuleList(resnets)
267
+ self.motion_modules = nn.ModuleList(motion_modules)
268
+
269
+ def forward(
270
+ self,
271
+ hidden_states,
272
+ temb=None,
273
+ encoder_hidden_states=None,
274
+ attention_mask=None,
275
+ ):
276
+ hidden_states = self.resnets[0](hidden_states, temb)
277
+ for attn, resnet, motion_module in zip(
278
+ self.attentions, self.resnets[1:], self.motion_modules
279
+ ):
280
+ hidden_states = attn(
281
+ hidden_states,
282
+ encoder_hidden_states=encoder_hidden_states,
283
+ ).sample
284
+ hidden_states = (
285
+ motion_module(
286
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
287
+ )
288
+ if motion_module is not None
289
+ else hidden_states
290
+ )
291
+ hidden_states = resnet(hidden_states, temb)
292
+
293
+ return hidden_states
294
+
295
+
296
+ class CrossAttnDownBlock3D(nn.Module):
297
+ def __init__(
298
+ self,
299
+ in_channels: int,
300
+ out_channels: int,
301
+ temb_channels: int,
302
+ dropout: float = 0.0,
303
+ num_layers: int = 1,
304
+ resnet_eps: float = 1e-6,
305
+ resnet_time_scale_shift: str = "default",
306
+ resnet_act_fn: str = "swish",
307
+ resnet_groups: int = 32,
308
+ resnet_pre_norm: bool = True,
309
+ attn_num_head_channels=1,
310
+ cross_attention_dim=1280,
311
+ output_scale_factor=1.0,
312
+ downsample_padding=1,
313
+ add_downsample=True,
314
+ dual_cross_attention=False,
315
+ use_linear_projection=False,
316
+ only_cross_attention=False,
317
+ upcast_attention=False,
318
+ unet_use_cross_frame_attention=None,
319
+ unet_use_temporal_attention=None,
320
+ use_inflated_groupnorm=None,
321
+ use_motion_module=None,
322
+ motion_module_type=None,
323
+ motion_module_kwargs=None,
324
+ ):
325
+ super().__init__()
326
+ resnets = []
327
+ attentions = []
328
+ motion_modules = []
329
+
330
+ self.has_cross_attention = True
331
+ self.attn_num_head_channels = attn_num_head_channels
332
+
333
+ for i in range(num_layers):
334
+ in_channels = in_channels if i == 0 else out_channels
335
+ resnets.append(
336
+ ResnetBlock3D(
337
+ in_channels=in_channels,
338
+ out_channels=out_channels,
339
+ temb_channels=temb_channels,
340
+ eps=resnet_eps,
341
+ groups=resnet_groups,
342
+ dropout=dropout,
343
+ time_embedding_norm=resnet_time_scale_shift,
344
+ non_linearity=resnet_act_fn,
345
+ output_scale_factor=output_scale_factor,
346
+ pre_norm=resnet_pre_norm,
347
+ use_inflated_groupnorm=use_inflated_groupnorm,
348
+ )
349
+ )
350
+ if dual_cross_attention:
351
+ raise NotImplementedError
352
+ attentions.append(
353
+ Transformer3DModel(
354
+ attn_num_head_channels,
355
+ out_channels // attn_num_head_channels,
356
+ in_channels=out_channels,
357
+ num_layers=1,
358
+ cross_attention_dim=cross_attention_dim,
359
+ norm_num_groups=resnet_groups,
360
+ use_linear_projection=use_linear_projection,
361
+ only_cross_attention=only_cross_attention,
362
+ upcast_attention=upcast_attention,
363
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
364
+ unet_use_temporal_attention=unet_use_temporal_attention,
365
+ )
366
+ )
367
+ motion_modules.append(
368
+ get_motion_module(
369
+ in_channels=out_channels,
370
+ motion_module_type=motion_module_type,
371
+ motion_module_kwargs=motion_module_kwargs,
372
+ )
373
+ if use_motion_module
374
+ else None
375
+ )
376
+
377
+ self.attentions = nn.ModuleList(attentions)
378
+ self.resnets = nn.ModuleList(resnets)
379
+ self.motion_modules = nn.ModuleList(motion_modules)
380
+
381
+ if add_downsample:
382
+ self.downsamplers = nn.ModuleList(
383
+ [
384
+ Downsample3D(
385
+ out_channels,
386
+ use_conv=True,
387
+ out_channels=out_channels,
388
+ padding=downsample_padding,
389
+ name="op",
390
+ )
391
+ ]
392
+ )
393
+ else:
394
+ self.downsamplers = None
395
+
396
+ self.gradient_checkpointing = False
397
+
398
+ def forward(
399
+ self,
400
+ hidden_states,
401
+ temb=None,
402
+ encoder_hidden_states=None,
403
+ attention_mask=None,
404
+ ):
405
+ output_states = ()
406
+
407
+ for i, (resnet, attn, motion_module) in enumerate(
408
+ zip(self.resnets, self.attentions, self.motion_modules)
409
+ ):
410
+ # self.gradient_checkpointing = False
411
+ if self.training and self.gradient_checkpointing:
412
+
413
+ def create_custom_forward(module, return_dict=None):
414
+ def custom_forward(*inputs):
415
+ if return_dict is not None:
416
+ return module(*inputs, return_dict=return_dict)
417
+ else:
418
+ return module(*inputs)
419
+
420
+ return custom_forward
421
+
422
+ hidden_states = torch.utils.checkpoint.checkpoint(
423
+ create_custom_forward(resnet), hidden_states, temb
424
+ )
425
+ hidden_states = torch.utils.checkpoint.checkpoint(
426
+ create_custom_forward(attn, return_dict=False),
427
+ hidden_states,
428
+ encoder_hidden_states,
429
+ )[0]
430
+
431
+ # add motion module
432
+ if motion_module is not None:
433
+ hidden_states = torch.utils.checkpoint.checkpoint(
434
+ create_custom_forward(motion_module),
435
+ hidden_states.requires_grad_(),
436
+ temb,
437
+ encoder_hidden_states,
438
+ )
439
+
440
+ # # add motion module
441
+ # hidden_states = (
442
+ # motion_module(
443
+ # hidden_states, temb, encoder_hidden_states=encoder_hidden_states
444
+ # )
445
+ # if motion_module is not None
446
+ # else hidden_states
447
+ # )
448
+
449
+ else:
450
+ hidden_states = resnet(hidden_states, temb)
451
+ hidden_states = attn(
452
+ hidden_states,
453
+ encoder_hidden_states=encoder_hidden_states,
454
+ ).sample
455
+
456
+ # add motion module
457
+ hidden_states = (
458
+ motion_module(
459
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
460
+ )
461
+ if motion_module is not None
462
+ else hidden_states
463
+ )
464
+
465
+ output_states += (hidden_states,)
466
+
467
+ if self.downsamplers is not None:
468
+ for downsampler in self.downsamplers:
469
+ hidden_states = downsampler(hidden_states)
470
+
471
+ output_states += (hidden_states,)
472
+
473
+ return hidden_states, output_states
474
+
475
+
476
+ class DownBlock3D(nn.Module):
477
+ def __init__(
478
+ self,
479
+ in_channels: int,
480
+ out_channels: int,
481
+ temb_channels: int,
482
+ dropout: float = 0.0,
483
+ num_layers: int = 1,
484
+ resnet_eps: float = 1e-6,
485
+ resnet_time_scale_shift: str = "default",
486
+ resnet_act_fn: str = "swish",
487
+ resnet_groups: int = 32,
488
+ resnet_pre_norm: bool = True,
489
+ output_scale_factor=1.0,
490
+ add_downsample=True,
491
+ downsample_padding=1,
492
+ use_inflated_groupnorm=None,
493
+ use_motion_module=None,
494
+ motion_module_type=None,
495
+ motion_module_kwargs=None,
496
+ ):
497
+ super().__init__()
498
+ resnets = []
499
+ motion_modules = []
500
+
501
+ # use_motion_module = False
502
+ for i in range(num_layers):
503
+ in_channels = in_channels if i == 0 else out_channels
504
+ resnets.append(
505
+ ResnetBlock3D(
506
+ in_channels=in_channels,
507
+ out_channels=out_channels,
508
+ temb_channels=temb_channels,
509
+ eps=resnet_eps,
510
+ groups=resnet_groups,
511
+ dropout=dropout,
512
+ time_embedding_norm=resnet_time_scale_shift,
513
+ non_linearity=resnet_act_fn,
514
+ output_scale_factor=output_scale_factor,
515
+ pre_norm=resnet_pre_norm,
516
+ use_inflated_groupnorm=use_inflated_groupnorm,
517
+ )
518
+ )
519
+ motion_modules.append(
520
+ get_motion_module(
521
+ in_channels=out_channels,
522
+ motion_module_type=motion_module_type,
523
+ motion_module_kwargs=motion_module_kwargs,
524
+ )
525
+ if use_motion_module
526
+ else None
527
+ )
528
+
529
+ self.resnets = nn.ModuleList(resnets)
530
+ self.motion_modules = nn.ModuleList(motion_modules)
531
+
532
+ if add_downsample:
533
+ self.downsamplers = nn.ModuleList(
534
+ [
535
+ Downsample3D(
536
+ out_channels,
537
+ use_conv=True,
538
+ out_channels=out_channels,
539
+ padding=downsample_padding,
540
+ name="op",
541
+ )
542
+ ]
543
+ )
544
+ else:
545
+ self.downsamplers = None
546
+
547
+ self.gradient_checkpointing = False
548
+
549
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
550
+ output_states = ()
551
+
552
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
553
+ # print(f"DownBlock3D {self.gradient_checkpointing = }")
554
+ if self.training and self.gradient_checkpointing:
555
+
556
+ def create_custom_forward(module):
557
+ def custom_forward(*inputs):
558
+ return module(*inputs)
559
+
560
+ return custom_forward
561
+
562
+ hidden_states = torch.utils.checkpoint.checkpoint(
563
+ create_custom_forward(resnet), hidden_states, temb
564
+ )
565
+ if motion_module is not None:
566
+ hidden_states = torch.utils.checkpoint.checkpoint(
567
+ create_custom_forward(motion_module),
568
+ hidden_states.requires_grad_(),
569
+ temb,
570
+ encoder_hidden_states,
571
+ )
572
+ else:
573
+ hidden_states = resnet(hidden_states, temb)
574
+
575
+ # add motion module
576
+ hidden_states = (
577
+ motion_module(
578
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
579
+ )
580
+ if motion_module is not None
581
+ else hidden_states
582
+ )
583
+
584
+ output_states += (hidden_states,)
585
+
586
+ if self.downsamplers is not None:
587
+ for downsampler in self.downsamplers:
588
+ hidden_states = downsampler(hidden_states)
589
+
590
+ output_states += (hidden_states,)
591
+
592
+ return hidden_states, output_states
593
+
594
+
595
+ class CrossAttnUpBlock3D(nn.Module):
596
+ def __init__(
597
+ self,
598
+ in_channels: int,
599
+ out_channels: int,
600
+ prev_output_channel: int,
601
+ temb_channels: int,
602
+ dropout: float = 0.0,
603
+ num_layers: int = 1,
604
+ resnet_eps: float = 1e-6,
605
+ resnet_time_scale_shift: str = "default",
606
+ resnet_act_fn: str = "swish",
607
+ resnet_groups: int = 32,
608
+ resnet_pre_norm: bool = True,
609
+ attn_num_head_channels=1,
610
+ cross_attention_dim=1280,
611
+ output_scale_factor=1.0,
612
+ add_upsample=True,
613
+ dual_cross_attention=False,
614
+ use_linear_projection=False,
615
+ only_cross_attention=False,
616
+ upcast_attention=False,
617
+ unet_use_cross_frame_attention=None,
618
+ unet_use_temporal_attention=None,
619
+ use_motion_module=None,
620
+ use_inflated_groupnorm=None,
621
+ motion_module_type=None,
622
+ motion_module_kwargs=None,
623
+ ):
624
+ super().__init__()
625
+ resnets = []
626
+ attentions = []
627
+ motion_modules = []
628
+
629
+ self.has_cross_attention = True
630
+ self.attn_num_head_channels = attn_num_head_channels
631
+
632
+ for i in range(num_layers):
633
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
634
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
635
+
636
+ resnets.append(
637
+ ResnetBlock3D(
638
+ in_channels=resnet_in_channels + res_skip_channels,
639
+ out_channels=out_channels,
640
+ temb_channels=temb_channels,
641
+ eps=resnet_eps,
642
+ groups=resnet_groups,
643
+ dropout=dropout,
644
+ time_embedding_norm=resnet_time_scale_shift,
645
+ non_linearity=resnet_act_fn,
646
+ output_scale_factor=output_scale_factor,
647
+ pre_norm=resnet_pre_norm,
648
+ use_inflated_groupnorm=use_inflated_groupnorm,
649
+ )
650
+ )
651
+ if dual_cross_attention:
652
+ raise NotImplementedError
653
+ attentions.append(
654
+ Transformer3DModel(
655
+ attn_num_head_channels,
656
+ out_channels // attn_num_head_channels,
657
+ in_channels=out_channels,
658
+ num_layers=1,
659
+ cross_attention_dim=cross_attention_dim,
660
+ norm_num_groups=resnet_groups,
661
+ use_linear_projection=use_linear_projection,
662
+ only_cross_attention=only_cross_attention,
663
+ upcast_attention=upcast_attention,
664
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
665
+ unet_use_temporal_attention=unet_use_temporal_attention,
666
+ )
667
+ )
668
+ motion_modules.append(
669
+ get_motion_module(
670
+ in_channels=out_channels,
671
+ motion_module_type=motion_module_type,
672
+ motion_module_kwargs=motion_module_kwargs,
673
+ )
674
+ if use_motion_module
675
+ else None
676
+ )
677
+
678
+ self.attentions = nn.ModuleList(attentions)
679
+ self.resnets = nn.ModuleList(resnets)
680
+ self.motion_modules = nn.ModuleList(motion_modules)
681
+
682
+ if add_upsample:
683
+ self.upsamplers = nn.ModuleList(
684
+ [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
685
+ )
686
+ else:
687
+ self.upsamplers = None
688
+
689
+ self.gradient_checkpointing = False
690
+
691
+ def forward(
692
+ self,
693
+ hidden_states,
694
+ res_hidden_states_tuple,
695
+ temb=None,
696
+ encoder_hidden_states=None,
697
+ upsample_size=None,
698
+ attention_mask=None,
699
+ ):
700
+ for i, (resnet, attn, motion_module) in enumerate(
701
+ zip(self.resnets, self.attentions, self.motion_modules)
702
+ ):
703
+ # pop res hidden states
704
+ res_hidden_states = res_hidden_states_tuple[-1]
705
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
706
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
707
+
708
+ if self.training and self.gradient_checkpointing:
709
+
710
+ def create_custom_forward(module, return_dict=None):
711
+ def custom_forward(*inputs):
712
+ if return_dict is not None:
713
+ return module(*inputs, return_dict=return_dict)
714
+ else:
715
+ return module(*inputs)
716
+
717
+ return custom_forward
718
+
719
+ hidden_states = torch.utils.checkpoint.checkpoint(
720
+ create_custom_forward(resnet), hidden_states, temb
721
+ )
722
+ hidden_states = attn(
723
+ hidden_states,
724
+ encoder_hidden_states=encoder_hidden_states,
725
+ ).sample
726
+ if motion_module is not None:
727
+ hidden_states = torch.utils.checkpoint.checkpoint(
728
+ create_custom_forward(motion_module),
729
+ hidden_states.requires_grad_(),
730
+ temb,
731
+ encoder_hidden_states,
732
+ )
733
+
734
+ else:
735
+ hidden_states = resnet(hidden_states, temb)
736
+ hidden_states = attn(
737
+ hidden_states,
738
+ encoder_hidden_states=encoder_hidden_states,
739
+ ).sample
740
+
741
+ # add motion module
742
+ hidden_states = (
743
+ motion_module(
744
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
745
+ )
746
+ if motion_module is not None
747
+ else hidden_states
748
+ )
749
+
750
+ if self.upsamplers is not None:
751
+ for upsampler in self.upsamplers:
752
+ hidden_states = upsampler(hidden_states, upsample_size)
753
+
754
+ return hidden_states
755
+
756
+
757
+ class UpBlock3D(nn.Module):
758
+ def __init__(
759
+ self,
760
+ in_channels: int,
761
+ prev_output_channel: int,
762
+ out_channels: int,
763
+ temb_channels: int,
764
+ dropout: float = 0.0,
765
+ num_layers: int = 1,
766
+ resnet_eps: float = 1e-6,
767
+ resnet_time_scale_shift: str = "default",
768
+ resnet_act_fn: str = "swish",
769
+ resnet_groups: int = 32,
770
+ resnet_pre_norm: bool = True,
771
+ output_scale_factor=1.0,
772
+ add_upsample=True,
773
+ use_inflated_groupnorm=None,
774
+ use_motion_module=None,
775
+ motion_module_type=None,
776
+ motion_module_kwargs=None,
777
+ ):
778
+ super().__init__()
779
+ resnets = []
780
+ motion_modules = []
781
+
782
+ # use_motion_module = False
783
+ for i in range(num_layers):
784
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
785
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
786
+
787
+ resnets.append(
788
+ ResnetBlock3D(
789
+ in_channels=resnet_in_channels + res_skip_channels,
790
+ out_channels=out_channels,
791
+ temb_channels=temb_channels,
792
+ eps=resnet_eps,
793
+ groups=resnet_groups,
794
+ dropout=dropout,
795
+ time_embedding_norm=resnet_time_scale_shift,
796
+ non_linearity=resnet_act_fn,
797
+ output_scale_factor=output_scale_factor,
798
+ pre_norm=resnet_pre_norm,
799
+ use_inflated_groupnorm=use_inflated_groupnorm,
800
+ )
801
+ )
802
+ motion_modules.append(
803
+ get_motion_module(
804
+ in_channels=out_channels,
805
+ motion_module_type=motion_module_type,
806
+ motion_module_kwargs=motion_module_kwargs,
807
+ )
808
+ if use_motion_module
809
+ else None
810
+ )
811
+
812
+ self.resnets = nn.ModuleList(resnets)
813
+ self.motion_modules = nn.ModuleList(motion_modules)
814
+
815
+ if add_upsample:
816
+ self.upsamplers = nn.ModuleList(
817
+ [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
818
+ )
819
+ else:
820
+ self.upsamplers = None
821
+
822
+ self.gradient_checkpointing = False
823
+
824
+ def forward(
825
+ self,
826
+ hidden_states,
827
+ res_hidden_states_tuple,
828
+ temb=None,
829
+ upsample_size=None,
830
+ encoder_hidden_states=None,
831
+ ):
832
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
833
+ # pop res hidden states
834
+ res_hidden_states = res_hidden_states_tuple[-1]
835
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
836
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
837
+
838
+ # print(f"UpBlock3D {self.gradient_checkpointing = }")
839
+ if self.training and self.gradient_checkpointing:
840
+
841
+ def create_custom_forward(module):
842
+ def custom_forward(*inputs):
843
+ return module(*inputs)
844
+
845
+ return custom_forward
846
+
847
+ hidden_states = torch.utils.checkpoint.checkpoint(
848
+ create_custom_forward(resnet), hidden_states, temb
849
+ )
850
+ if motion_module is not None:
851
+ hidden_states = torch.utils.checkpoint.checkpoint(
852
+ create_custom_forward(motion_module),
853
+ hidden_states.requires_grad_(),
854
+ temb,
855
+ encoder_hidden_states,
856
+ )
857
+ else:
858
+ hidden_states = resnet(hidden_states, temb)
859
+ hidden_states = (
860
+ motion_module(
861
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
862
+ )
863
+ if motion_module is not None
864
+ else hidden_states
865
+ )
866
+
867
+ if self.upsamplers is not None:
868
+ for upsampler in self.upsamplers:
869
+ hidden_states = upsampler(hidden_states, upsample_size)
870
+
871
+ return hidden_states
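
As an aside (not part of the commit), the create_custom_forward helper that recurs in the blocks above is the standard wrapper that lets torch.utils.checkpoint re-run a submodule during the backward pass instead of caching its activations. A minimal, self-contained sketch, with a placeholder nn.Linear standing in for a heavy ResnetBlock3D and made-up shapes:

import torch
import torch.nn as nn
import torch.utils.checkpoint

def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

layer = nn.Linear(64, 64)                    # stand-in for a heavy 3D block
x = torch.randn(2, 64, requires_grad=True)
# Activations of `layer` are recomputed during backward instead of stored.
y = torch.utils.checkpoint.checkpoint(create_custom_forward(layer), x)
y.sum().backward()
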
musepose/pipelines/__init__.py ADDED
File without changes
musepose/pipelines/context.py ADDED
@@ -0,0 +1,76 @@
+ # TODO: Adapted from cli
+ from typing import Callable, List, Optional
+
+ import numpy as np
+
+
+ def ordered_halving(val):
+     bin_str = f"{val:064b}"
+     bin_flip = bin_str[::-1]
+     as_int = int(bin_flip, 2)
+
+     return as_int / (1 << 64)
+
+
+ def uniform(
+     step: int = ...,
+     num_steps: Optional[int] = None,
+     num_frames: int = ...,
+     context_size: Optional[int] = None,
+     context_stride: int = 3,
+     context_overlap: int = 4,
+     closed_loop: bool = False,
+ ):
+     if num_frames <= context_size:
+         yield list(range(num_frames))
+         return
+
+     context_stride = min(
+         context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1
+     )
+
+     for context_step in 1 << np.arange(context_stride):
+         pad = int(round(num_frames * ordered_halving(step)))
+         for j in range(
+             int(ordered_halving(step) * context_step) + pad,
+             num_frames + pad + (0 if closed_loop else -context_overlap),
+             (context_size * context_step - context_overlap),
+         ):
+             yield [
+                 e % num_frames
+                 for e in range(j, j + context_size * context_step, context_step)
+             ]
+
+
+ def get_context_scheduler(name: str) -> Callable:
+     if name == "uniform":
+         return uniform
+     else:
+         raise ValueError(f"Unknown context_overlap policy {name}")
+
+
+ def get_total_steps(
+     scheduler,
+     timesteps: List[int],
+     num_steps: Optional[int] = None,
+     num_frames: int = ...,
+     context_size: Optional[int] = None,
+     context_stride: int = 3,
+     context_overlap: int = 4,
+     closed_loop: bool = True,
+ ):
+     return sum(
+         len(
+             list(
+                 scheduler(
+                     i,
+                     num_steps,
+                     num_frames,
+                     context_size,
+                     context_stride,
+                     context_overlap,
+                 )
+             )
+         )
+         for i in range(len(timesteps))
+     )
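
A brief usage sketch (not from the commit itself): the "uniform" scheduler yields overlapping frame-index windows that the long-video pipeline denoises one at a time. The frame count, window size, and overlap below are illustrative values, not defaults taken from the repository:

from musepose.pipelines.context import get_context_scheduler

context_scheduler = get_context_scheduler("uniform")
windows = list(
    context_scheduler(
        step=0,            # current denoising step; shifts the window offsets
        num_steps=25,      # total denoising steps (unused by "uniform" itself)
        num_frames=72,     # latent frames in the clip
        context_size=16,   # frames per UNet forward pass
        context_stride=1,
        context_overlap=4,
    )
)
print(windows[0])  # [0, 1, ..., 15]; the next window starts at frame 12
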
musepose/pipelines/pipeline_pose2img.py ADDED
@@ -0,0 +1,360 @@
1
+ import inspect
2
+ from dataclasses import dataclass
3
+ from typing import Callable, List, Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from diffusers import DiffusionPipeline
8
+ from diffusers.image_processor import VaeImageProcessor
9
+ from diffusers.schedulers import (
10
+ DDIMScheduler,
11
+ DPMSolverMultistepScheduler,
12
+ EulerAncestralDiscreteScheduler,
13
+ EulerDiscreteScheduler,
14
+ LMSDiscreteScheduler,
15
+ PNDMScheduler,
16
+ )
17
+ from diffusers.utils import BaseOutput, is_accelerate_available
18
+ from diffusers.utils.torch_utils import randn_tensor
19
+ from einops import rearrange
20
+ from tqdm import tqdm
21
+ from transformers import CLIPImageProcessor
22
+
23
+ from musepose.models.mutual_self_attention import ReferenceAttentionControl
24
+
25
+
26
+ @dataclass
27
+ class Pose2ImagePipelineOutput(BaseOutput):
28
+ images: Union[torch.Tensor, np.ndarray]
29
+
30
+
31
+ class Pose2ImagePipeline(DiffusionPipeline):
32
+ _optional_components = []
33
+
34
+ def __init__(
35
+ self,
36
+ vae,
37
+ image_encoder,
38
+ reference_unet,
39
+ denoising_unet,
40
+ pose_guider,
41
+ scheduler: Union[
42
+ DDIMScheduler,
43
+ PNDMScheduler,
44
+ LMSDiscreteScheduler,
45
+ EulerDiscreteScheduler,
46
+ EulerAncestralDiscreteScheduler,
47
+ DPMSolverMultistepScheduler,
48
+ ],
49
+ ):
50
+ super().__init__()
51
+
52
+ self.register_modules(
53
+ vae=vae,
54
+ image_encoder=image_encoder,
55
+ reference_unet=reference_unet,
56
+ denoising_unet=denoising_unet,
57
+ pose_guider=pose_guider,
58
+ scheduler=scheduler,
59
+ )
60
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
61
+ self.clip_image_processor = CLIPImageProcessor()
62
+ self.ref_image_processor = VaeImageProcessor(
63
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
64
+ )
65
+ self.cond_image_processor = VaeImageProcessor(
66
+ vae_scale_factor=self.vae_scale_factor,
67
+ do_convert_rgb=True,
68
+ do_normalize=False,
69
+ )
70
+
71
+ def enable_vae_slicing(self):
72
+ self.vae.enable_slicing()
73
+
74
+ def disable_vae_slicing(self):
75
+ self.vae.disable_slicing()
76
+
77
+ def enable_sequential_cpu_offload(self, gpu_id=0):
78
+ if is_accelerate_available():
79
+ from accelerate import cpu_offload
80
+ else:
81
+ raise ImportError("Please install accelerate via `pip install accelerate`")
82
+
83
+ device = torch.device(f"cuda:{gpu_id}")
84
+
85
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
86
+ if cpu_offloaded_model is not None:
87
+ cpu_offload(cpu_offloaded_model, device)
88
+
89
+ @property
90
+ def _execution_device(self):
91
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
92
+ return self.device
93
+ for module in self.unet.modules():
94
+ if (
95
+ hasattr(module, "_hf_hook")
96
+ and hasattr(module._hf_hook, "execution_device")
97
+ and module._hf_hook.execution_device is not None
98
+ ):
99
+ return torch.device(module._hf_hook.execution_device)
100
+ return self.device
101
+
102
+ def decode_latents(self, latents):
103
+ video_length = latents.shape[2]
104
+ latents = 1 / 0.18215 * latents
105
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
106
+ # video = self.vae.decode(latents).sample
107
+ video = []
108
+ for frame_idx in tqdm(range(latents.shape[0])):
109
+ video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
110
+ video = torch.cat(video)
111
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
112
+ video = (video / 2 + 0.5).clamp(0, 1)
113
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
114
+ video = video.cpu().float().numpy()
115
+ return video
116
+
117
+ def prepare_extra_step_kwargs(self, generator, eta):
118
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
119
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
120
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
121
+ # and should be between [0, 1]
122
+
123
+ accepts_eta = "eta" in set(
124
+ inspect.signature(self.scheduler.step).parameters.keys()
125
+ )
126
+ extra_step_kwargs = {}
127
+ if accepts_eta:
128
+ extra_step_kwargs["eta"] = eta
129
+
130
+ # check if the scheduler accepts generator
131
+ accepts_generator = "generator" in set(
132
+ inspect.signature(self.scheduler.step).parameters.keys()
133
+ )
134
+ if accepts_generator:
135
+ extra_step_kwargs["generator"] = generator
136
+ return extra_step_kwargs
137
+
138
+ def prepare_latents(
139
+ self,
140
+ batch_size,
141
+ num_channels_latents,
142
+ width,
143
+ height,
144
+ dtype,
145
+ device,
146
+ generator,
147
+ latents=None,
148
+ ):
149
+ shape = (
150
+ batch_size,
151
+ num_channels_latents,
152
+ height // self.vae_scale_factor,
153
+ width // self.vae_scale_factor,
154
+ )
155
+ if isinstance(generator, list) and len(generator) != batch_size:
156
+ raise ValueError(
157
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
158
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
159
+ )
160
+
161
+ if latents is None:
162
+ latents = randn_tensor(
163
+ shape, generator=generator, device=device, dtype=dtype
164
+ )
165
+ else:
166
+ latents = latents.to(device)
167
+
168
+ # scale the initial noise by the standard deviation required by the scheduler
169
+ latents = latents * self.scheduler.init_noise_sigma
170
+ return latents
171
+
172
+ def prepare_condition(
173
+ self,
174
+ cond_image,
175
+ width,
176
+ height,
177
+ device,
178
+ dtype,
179
+ do_classififer_free_guidance=False,
180
+ ):
181
+ image = self.cond_image_processor.preprocess(
182
+ cond_image, height=height, width=width
183
+ ).to(dtype=torch.float32)
184
+
185
+ image = image.to(device=device, dtype=dtype)
186
+
187
+ if do_classififer_free_guidance:
188
+ image = torch.cat([image] * 2)
189
+
190
+ return image
191
+
192
+ @torch.no_grad()
193
+ def __call__(
194
+ self,
195
+ ref_image,
196
+ pose_image,
197
+ width,
198
+ height,
199
+ num_inference_steps,
200
+ guidance_scale,
201
+ num_images_per_prompt=1,
202
+ eta: float = 0.0,
203
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
204
+ output_type: Optional[str] = "tensor",
205
+ return_dict: bool = True,
206
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
207
+ callback_steps: Optional[int] = 1,
208
+ **kwargs,
209
+ ):
210
+ # Default height and width to unet
211
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
212
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
213
+
214
+ device = self._execution_device
215
+
216
+ do_classifier_free_guidance = guidance_scale > 1.0
217
+
218
+ # Prepare timesteps
219
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
220
+ timesteps = self.scheduler.timesteps
221
+
222
+ batch_size = 1
223
+
224
+ # Prepare clip image embeds
225
+ clip_image = self.clip_image_processor.preprocess(
226
+ ref_image.resize((224, 224)), return_tensors="pt"
227
+ ).pixel_values
228
+ clip_image_embeds = self.image_encoder(
229
+ clip_image.to(device, dtype=self.image_encoder.dtype)
230
+ ).image_embeds
231
+ image_prompt_embeds = clip_image_embeds.unsqueeze(1)
232
+ uncond_image_prompt_embeds = torch.zeros_like(image_prompt_embeds)
233
+
234
+ if do_classifier_free_guidance:
235
+ image_prompt_embeds = torch.cat(
236
+ [uncond_image_prompt_embeds, image_prompt_embeds], dim=0
237
+ )
238
+
239
+ reference_control_writer = ReferenceAttentionControl(
240
+ self.reference_unet,
241
+ do_classifier_free_guidance=do_classifier_free_guidance,
242
+ mode="write",
243
+ batch_size=batch_size,
244
+ fusion_blocks="full",
245
+ )
246
+ reference_control_reader = ReferenceAttentionControl(
247
+ self.denoising_unet,
248
+ do_classifier_free_guidance=do_classifier_free_guidance,
249
+ mode="read",
250
+ batch_size=batch_size,
251
+ fusion_blocks="full",
252
+ )
253
+
254
+ num_channels_latents = self.denoising_unet.in_channels
255
+ latents = self.prepare_latents(
256
+ batch_size * num_images_per_prompt,
257
+ num_channels_latents,
258
+ width,
259
+ height,
260
+ clip_image_embeds.dtype,
261
+ device,
262
+ generator,
263
+ )
264
+ latents = latents.unsqueeze(2) # (bs, c, 1, h', w')
265
+ latents_dtype = latents.dtype
266
+
267
+ # Prepare extra step kwargs.
268
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
269
+
270
+ # Prepare ref image latents
271
+ ref_image_tensor = self.ref_image_processor.preprocess(
272
+ ref_image, height=height, width=width
273
+ ) # (bs, c, height, width)
274
+ ref_image_tensor = ref_image_tensor.to(
275
+ dtype=self.vae.dtype, device=self.vae.device
276
+ )
277
+ ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
278
+ ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
279
+
280
+ # Prepare pose condition image
281
+ pose_cond_tensor = self.cond_image_processor.preprocess(
282
+ pose_image, height=height, width=width
283
+ )
284
+ pose_cond_tensor = pose_cond_tensor.unsqueeze(2) # (bs, c, 1, h, w)
285
+ pose_cond_tensor = pose_cond_tensor.to(
286
+ device=device, dtype=self.pose_guider.dtype
287
+ )
288
+ pose_fea = self.pose_guider(pose_cond_tensor)
289
+ pose_fea = (
290
+ torch.cat([pose_fea] * 2) if do_classifier_free_guidance else pose_fea
291
+ )
292
+
293
+ # denoising loop
294
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
295
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
296
+ for i, t in enumerate(timesteps):
297
+ # 1. Forward reference image
298
+ if i == 0:
299
+ self.reference_unet(
300
+ ref_image_latents.repeat(
301
+ (2 if do_classifier_free_guidance else 1), 1, 1, 1
302
+ ),
303
+ torch.zeros_like(t),
304
+ encoder_hidden_states=image_prompt_embeds,
305
+ return_dict=False,
306
+ )
307
+
308
+ # 2. Update reference unet features into the denoising net
309
+ reference_control_reader.update(reference_control_writer)
310
+
311
+ # 3.1 expand the latents if we are doing classifier free guidance
312
+ latent_model_input = (
313
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
314
+ )
315
+ latent_model_input = self.scheduler.scale_model_input(
316
+ latent_model_input, t
317
+ )
318
+
319
+ noise_pred = self.denoising_unet(
320
+ latent_model_input,
321
+ t,
322
+ encoder_hidden_states=image_prompt_embeds,
323
+ pose_cond_fea=pose_fea,
324
+ return_dict=False,
325
+ )[0]
326
+
327
+ # perform guidance
328
+ if do_classifier_free_guidance:
329
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
330
+ noise_pred = noise_pred_uncond + guidance_scale * (
331
+ noise_pred_text - noise_pred_uncond
332
+ )
333
+
334
+ # compute the previous noisy sample x_t -> x_t-1
335
+ latents = self.scheduler.step(
336
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
337
+ )[0]
338
+
339
+ # call the callback, if provided
340
+ if i == len(timesteps) - 1 or (
341
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
342
+ ):
343
+ progress_bar.update()
344
+ if callback is not None and i % callback_steps == 0:
345
+ step_idx = i // getattr(self.scheduler, "order", 1)
346
+ callback(step_idx, t, latents)
347
+ reference_control_reader.clear()
348
+ reference_control_writer.clear()
349
+
350
+ # Post-processing
351
+ image = self.decode_latents(latents) # (b, c, 1, h, w)
352
+
353
+ # Convert to tensor
354
+ if output_type == "tensor":
355
+ image = torch.from_numpy(image)
356
+
357
+ if not return_dict:
358
+ return image
359
+
360
+ return Pose2ImagePipelineOutput(images=image)
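
The guidance step inside the denoising loop above follows the usual classifier-free guidance arithmetic: the batch carries the unconditional and conditional predictions stacked along dim 0, and the final prediction extrapolates from the unconditional one. A tiny self-contained illustration with random tensors and an arbitrary guidance scale:

import torch

guidance_scale = 3.5                        # illustrative value
noise_pred = torch.randn(2, 4, 1, 96, 64)   # (uncond + cond, c, f, h, w)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(guided.shape)                         # torch.Size([1, 4, 1, 96, 64])
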
musepose/pipelines/pipeline_pose2vid.py ADDED
@@ -0,0 +1,458 @@
1
+ import inspect
2
+ from dataclasses import dataclass
3
+ from typing import Callable, List, Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from diffusers import DiffusionPipeline
8
+ from diffusers.image_processor import VaeImageProcessor
9
+ from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler,
10
+ EulerAncestralDiscreteScheduler,
11
+ EulerDiscreteScheduler, LMSDiscreteScheduler,
12
+ PNDMScheduler)
13
+ from diffusers.utils import BaseOutput, is_accelerate_available
14
+ from diffusers.utils.torch_utils import randn_tensor
15
+ from einops import rearrange
16
+ from tqdm import tqdm
17
+ from transformers import CLIPImageProcessor
18
+
19
+ from musepose.models.mutual_self_attention import ReferenceAttentionControl
20
+
21
+
22
+ @dataclass
23
+ class Pose2VideoPipelineOutput(BaseOutput):
24
+ videos: Union[torch.Tensor, np.ndarray]
25
+
26
+
27
+ class Pose2VideoPipeline(DiffusionPipeline):
28
+ _optional_components = []
29
+
30
+ def __init__(
31
+ self,
32
+ vae,
33
+ image_encoder,
34
+ reference_unet,
35
+ denoising_unet,
36
+ pose_guider,
37
+ scheduler: Union[
38
+ DDIMScheduler,
39
+ PNDMScheduler,
40
+ LMSDiscreteScheduler,
41
+ EulerDiscreteScheduler,
42
+ EulerAncestralDiscreteScheduler,
43
+ DPMSolverMultistepScheduler,
44
+ ],
45
+ image_proj_model=None,
46
+ tokenizer=None,
47
+ text_encoder=None,
48
+ ):
49
+ super().__init__()
50
+
51
+ self.register_modules(
52
+ vae=vae,
53
+ image_encoder=image_encoder,
54
+ reference_unet=reference_unet,
55
+ denoising_unet=denoising_unet,
56
+ pose_guider=pose_guider,
57
+ scheduler=scheduler,
58
+ image_proj_model=image_proj_model,
59
+ tokenizer=tokenizer,
60
+ text_encoder=text_encoder,
61
+ )
62
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
63
+ self.clip_image_processor = CLIPImageProcessor()
64
+ self.ref_image_processor = VaeImageProcessor(
65
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
66
+ )
67
+ self.cond_image_processor = VaeImageProcessor(
68
+ vae_scale_factor=self.vae_scale_factor,
69
+ do_convert_rgb=True,
70
+ do_normalize=False,
71
+ )
72
+
73
+ def enable_vae_slicing(self):
74
+ self.vae.enable_slicing()
75
+
76
+ def disable_vae_slicing(self):
77
+ self.vae.disable_slicing()
78
+
79
+ def enable_sequential_cpu_offload(self, gpu_id=0):
80
+ if is_accelerate_available():
81
+ from accelerate import cpu_offload
82
+ else:
83
+ raise ImportError("Please install accelerate via `pip install accelerate`")
84
+
85
+ device = torch.device(f"cuda:{gpu_id}")
86
+
87
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
88
+ if cpu_offloaded_model is not None:
89
+ cpu_offload(cpu_offloaded_model, device)
90
+
91
+ @property
92
+ def _execution_device(self):
93
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
94
+ return self.device
95
+ for module in self.unet.modules():
96
+ if (
97
+ hasattr(module, "_hf_hook")
98
+ and hasattr(module._hf_hook, "execution_device")
99
+ and module._hf_hook.execution_device is not None
100
+ ):
101
+ return torch.device(module._hf_hook.execution_device)
102
+ return self.device
103
+
104
+ def decode_latents(self, latents):
105
+ video_length = latents.shape[2]
106
+ latents = 1 / 0.18215 * latents
107
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
108
+ # video = self.vae.decode(latents).sample
109
+ video = []
110
+ for frame_idx in tqdm(range(latents.shape[0])):
111
+ video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
112
+ video = torch.cat(video)
113
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
114
+ video = (video / 2 + 0.5).clamp(0, 1)
115
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
116
+ video = video.cpu().float().numpy()
117
+ return video
118
+
119
+ def prepare_extra_step_kwargs(self, generator, eta):
120
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
121
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
122
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
123
+ # and should be between [0, 1]
124
+
125
+ accepts_eta = "eta" in set(
126
+ inspect.signature(self.scheduler.step).parameters.keys()
127
+ )
128
+ extra_step_kwargs = {}
129
+ if accepts_eta:
130
+ extra_step_kwargs["eta"] = eta
131
+
132
+ # check if the scheduler accepts generator
133
+ accepts_generator = "generator" in set(
134
+ inspect.signature(self.scheduler.step).parameters.keys()
135
+ )
136
+ if accepts_generator:
137
+ extra_step_kwargs["generator"] = generator
138
+ return extra_step_kwargs
139
+
140
+ def prepare_latents(
141
+ self,
142
+ batch_size,
143
+ num_channels_latents,
144
+ width,
145
+ height,
146
+ video_length,
147
+ dtype,
148
+ device,
149
+ generator,
150
+ latents=None,
151
+ ):
152
+ shape = (
153
+ batch_size,
154
+ num_channels_latents,
155
+ video_length,
156
+ height // self.vae_scale_factor,
157
+ width // self.vae_scale_factor,
158
+ )
159
+ if isinstance(generator, list) and len(generator) != batch_size:
160
+ raise ValueError(
161
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
162
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
163
+ )
164
+
165
+ if latents is None:
166
+ latents = randn_tensor(
167
+ shape, generator=generator, device=device, dtype=dtype
168
+ )
169
+ else:
170
+ latents = latents.to(device)
171
+
172
+ # scale the initial noise by the standard deviation required by the scheduler
173
+ latents = latents * self.scheduler.init_noise_sigma
174
+ return latents
175
+
176
+ def _encode_prompt(
177
+ self,
178
+ prompt,
179
+ device,
180
+ num_videos_per_prompt,
181
+ do_classifier_free_guidance,
182
+ negative_prompt,
183
+ ):
184
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
185
+
186
+ text_inputs = self.tokenizer(
187
+ prompt,
188
+ padding="max_length",
189
+ max_length=self.tokenizer.model_max_length,
190
+ truncation=True,
191
+ return_tensors="pt",
192
+ )
193
+ text_input_ids = text_inputs.input_ids
194
+ untruncated_ids = self.tokenizer(
195
+ prompt, padding="longest", return_tensors="pt"
196
+ ).input_ids
197
+
198
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
199
+ text_input_ids, untruncated_ids
200
+ ):
201
+ removed_text = self.tokenizer.batch_decode(
202
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
203
+ )
204
+
205
+ if (
206
+ hasattr(self.text_encoder.config, "use_attention_mask")
207
+ and self.text_encoder.config.use_attention_mask
208
+ ):
209
+ attention_mask = text_inputs.attention_mask.to(device)
210
+ else:
211
+ attention_mask = None
212
+
213
+ text_embeddings = self.text_encoder(
214
+ text_input_ids.to(device),
215
+ attention_mask=attention_mask,
216
+ )
217
+ text_embeddings = text_embeddings[0]
218
+
219
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
220
+ bs_embed, seq_len, _ = text_embeddings.shape
221
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
222
+ text_embeddings = text_embeddings.view(
223
+ bs_embed * num_videos_per_prompt, seq_len, -1
224
+ )
225
+
226
+ # get unconditional embeddings for classifier free guidance
227
+ if do_classifier_free_guidance:
228
+ uncond_tokens: List[str]
229
+ if negative_prompt is None:
230
+ uncond_tokens = [""] * batch_size
231
+ elif type(prompt) is not type(negative_prompt):
232
+ raise TypeError(
233
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
234
+ f" {type(prompt)}."
235
+ )
236
+ elif isinstance(negative_prompt, str):
237
+ uncond_tokens = [negative_prompt]
238
+ elif batch_size != len(negative_prompt):
239
+ raise ValueError(
240
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
241
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
242
+ " the batch size of `prompt`."
243
+ )
244
+ else:
245
+ uncond_tokens = negative_prompt
246
+
247
+ max_length = text_input_ids.shape[-1]
248
+ uncond_input = self.tokenizer(
249
+ uncond_tokens,
250
+ padding="max_length",
251
+ max_length=max_length,
252
+ truncation=True,
253
+ return_tensors="pt",
254
+ )
255
+
256
+ if (
257
+ hasattr(self.text_encoder.config, "use_attention_mask")
258
+ and self.text_encoder.config.use_attention_mask
259
+ ):
260
+ attention_mask = uncond_input.attention_mask.to(device)
261
+ else:
262
+ attention_mask = None
263
+
264
+ uncond_embeddings = self.text_encoder(
265
+ uncond_input.input_ids.to(device),
266
+ attention_mask=attention_mask,
267
+ )
268
+ uncond_embeddings = uncond_embeddings[0]
269
+
270
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
271
+ seq_len = uncond_embeddings.shape[1]
272
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
273
+ uncond_embeddings = uncond_embeddings.view(
274
+ batch_size * num_videos_per_prompt, seq_len, -1
275
+ )
276
+
277
+ # For classifier free guidance, we need to do two forward passes.
278
+ # Here we concatenate the unconditional and text embeddings into a single batch
279
+ # to avoid doing two forward passes
280
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
281
+
282
+ return text_embeddings
283
+
284
+ @torch.no_grad()
285
+ def __call__(
286
+ self,
287
+ ref_image,
288
+ pose_images,
289
+ width,
290
+ height,
291
+ video_length,
292
+ num_inference_steps,
293
+ guidance_scale,
294
+ num_images_per_prompt=1,
295
+ eta: float = 0.0,
296
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
297
+ output_type: Optional[str] = "tensor",
298
+ return_dict: bool = True,
299
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
300
+ callback_steps: Optional[int] = 1,
301
+ **kwargs,
302
+ ):
303
+ # Default height and width to unet
304
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
305
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
306
+
307
+ device = self._execution_device
308
+
309
+ do_classifier_free_guidance = guidance_scale > 1.0
310
+
311
+ # Prepare timesteps
312
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
313
+ timesteps = self.scheduler.timesteps
314
+
315
+ batch_size = 1
316
+
317
+ # Prepare clip image embeds
318
+ clip_image = self.clip_image_processor.preprocess(
319
+ ref_image, return_tensors="pt"
320
+ ).pixel_values
321
+ clip_image_embeds = self.image_encoder(
322
+ clip_image.to(device, dtype=self.image_encoder.dtype)
323
+ ).image_embeds
324
+ encoder_hidden_states = clip_image_embeds.unsqueeze(1)
325
+ uncond_encoder_hidden_states = torch.zeros_like(encoder_hidden_states)
326
+
327
+ if do_classifier_free_guidance:
328
+ encoder_hidden_states = torch.cat(
329
+ [uncond_encoder_hidden_states, encoder_hidden_states], dim=0
330
+ )
331
+ reference_control_writer = ReferenceAttentionControl(
332
+ self.reference_unet,
333
+ do_classifier_free_guidance=do_classifier_free_guidance,
334
+ mode="write",
335
+ batch_size=batch_size,
336
+ fusion_blocks="full",
337
+ )
338
+ reference_control_reader = ReferenceAttentionControl(
339
+ self.denoising_unet,
340
+ do_classifier_free_guidance=do_classifier_free_guidance,
341
+ mode="read",
342
+ batch_size=batch_size,
343
+ fusion_blocks="full",
344
+ )
345
+
346
+ num_channels_latents = self.denoising_unet.in_channels
347
+ latents = self.prepare_latents(
348
+ batch_size * num_images_per_prompt,
349
+ num_channels_latents,
350
+ width,
351
+ height,
352
+ video_length,
353
+ clip_image_embeds.dtype,
354
+ device,
355
+ generator,
356
+ )
357
+
358
+ # Prepare extra step kwargs.
359
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
360
+
361
+ # Prepare ref image latents
362
+ ref_image_tensor = self.ref_image_processor.preprocess(
363
+ ref_image, height=height, width=width
364
+ ) # (bs, c, height, width)
365
+ ref_image_tensor = ref_image_tensor.to(
366
+ dtype=self.vae.dtype, device=self.vae.device
367
+ )
368
+ ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
369
+ ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
370
+
371
+ # Prepare a list of pose condition images
372
+ pose_cond_tensor_list = []
373
+ for pose_image in pose_images:
374
+ pose_cond_tensor = (
375
+ torch.from_numpy(np.array(pose_image.resize((width, height)))) / 255.0
376
+ )
377
+ pose_cond_tensor = pose_cond_tensor.permute(2, 0, 1).unsqueeze(
378
+ 1
379
+ ) # (c, 1, h, w)
380
+ pose_cond_tensor_list.append(pose_cond_tensor)
381
+ pose_cond_tensor = torch.cat(pose_cond_tensor_list, dim=1) # (c, t, h, w)
382
+ pose_cond_tensor = pose_cond_tensor.unsqueeze(0)
383
+ pose_cond_tensor = pose_cond_tensor.to(
384
+ device=device, dtype=self.pose_guider.dtype
385
+ )
386
+ pose_fea = self.pose_guider(pose_cond_tensor)
387
+ pose_fea = (
388
+ torch.cat([pose_fea] * 2) if do_classifier_free_guidance else pose_fea
389
+ )
390
+
391
+ # denoising loop
392
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
393
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
394
+ for i, t in enumerate(timesteps):
395
+ # 1. Forward reference image
396
+ if i == 0:
397
+ self.reference_unet(
398
+ ref_image_latents.repeat(
399
+ (2 if do_classifier_free_guidance else 1), 1, 1, 1
400
+ ),
401
+ torch.zeros_like(t),
402
+ # t,
403
+ encoder_hidden_states=encoder_hidden_states,
404
+ return_dict=False,
405
+ )
406
+ reference_control_reader.update(reference_control_writer)
407
+
408
+ # 3.1 expand the latents if we are doing classifier free guidance
409
+ latent_model_input = (
410
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
411
+ )
412
+ latent_model_input = self.scheduler.scale_model_input(
413
+ latent_model_input, t
414
+ )
415
+
416
+ noise_pred = self.denoising_unet(
417
+ latent_model_input,
418
+ t,
419
+ encoder_hidden_states=encoder_hidden_states,
420
+ pose_cond_fea=pose_fea,
421
+ return_dict=False,
422
+ )[0]
423
+
424
+ # perform guidance
425
+ if do_classifier_free_guidance:
426
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
427
+ noise_pred = noise_pred_uncond + guidance_scale * (
428
+ noise_pred_text - noise_pred_uncond
429
+ )
430
+
431
+ # compute the previous noisy sample x_t -> x_t-1
432
+ latents = self.scheduler.step(
433
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
434
+ )[0]
435
+
436
+ # call the callback, if provided
437
+ if i == len(timesteps) - 1 or (
438
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
439
+ ):
440
+ progress_bar.update()
441
+ if callback is not None and i % callback_steps == 0:
442
+ step_idx = i // getattr(self.scheduler, "order", 1)
443
+ callback(step_idx, t, latents)
444
+
445
+ reference_control_reader.clear()
446
+ reference_control_writer.clear()
447
+
448
+ # Post-processing
449
+ images = self.decode_latents(latents) # (b, c, f, h, w)
450
+
451
+ # Convert to tensor
452
+ if output_type == "tensor":
453
+ images = torch.from_numpy(images)
454
+
455
+ if not return_dict:
456
+ return images
457
+
458
+ return Pose2VideoPipelineOutput(videos=images)
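
For reference, the pose-conditioning block above simply normalizes each per-frame pose image to [0, 1] and stacks the frames along a new time axis before handing the tensor to the pose guider. A standalone sketch of just that shape manipulation, with made-up sizes and blank images in place of real pose renders:

import numpy as np
import torch
from PIL import Image

width, height, video_length = 512, 768, 4
pose_images = [Image.new("RGB", (width, height)) for _ in range(video_length)]

frames = []
for pose_image in pose_images:
    frame = torch.from_numpy(np.array(pose_image.resize((width, height)))) / 255.0
    frames.append(frame.permute(2, 0, 1).unsqueeze(1))   # (c, 1, h, w)
pose_cond = torch.cat(frames, dim=1).unsqueeze(0)        # (1, c, t, h, w)
print(pose_cond.shape)                                   # torch.Size([1, 3, 4, 768, 512])
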
musepose/pipelines/pipeline_pose2vid_long.py ADDED
@@ -0,0 +1,571 @@
1
+ # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/pipelines/pipeline_animation.py
2
+ import inspect
3
+ import math
4
+ from dataclasses import dataclass
5
+ from typing import Callable, List, Optional, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ from diffusers import DiffusionPipeline
10
+ from diffusers.image_processor import VaeImageProcessor
11
+ from diffusers.schedulers import (
12
+ DDIMScheduler,
13
+ DPMSolverMultistepScheduler,
14
+ EulerAncestralDiscreteScheduler,
15
+ EulerDiscreteScheduler,
16
+ LMSDiscreteScheduler,
17
+ PNDMScheduler,
18
+ )
19
+ from diffusers.utils import BaseOutput, deprecate, is_accelerate_available, logging
20
+ from diffusers.utils.torch_utils import randn_tensor
21
+ from einops import rearrange
22
+ from tqdm import tqdm
23
+ from transformers import CLIPImageProcessor
24
+
25
+ from musepose.models.mutual_self_attention import ReferenceAttentionControl
26
+ from musepose.pipelines.context import get_context_scheduler
27
+ from musepose.pipelines.utils import get_tensor_interpolation_method
28
+
29
+
30
+ @dataclass
31
+ class Pose2VideoPipelineOutput(BaseOutput):
32
+ videos: Union[torch.Tensor, np.ndarray]
33
+
34
+
35
+ class Pose2VideoPipeline(DiffusionPipeline):
36
+ _optional_components = []
37
+
38
+ def __init__(
39
+ self,
40
+ vae,
41
+ image_encoder,
42
+ reference_unet,
43
+ denoising_unet,
44
+ pose_guider,
45
+ scheduler: Union[
46
+ DDIMScheduler,
47
+ PNDMScheduler,
48
+ LMSDiscreteScheduler,
49
+ EulerDiscreteScheduler,
50
+ EulerAncestralDiscreteScheduler,
51
+ DPMSolverMultistepScheduler,
52
+ ],
53
+ image_proj_model=None,
54
+ tokenizer=None,
55
+ text_encoder=None,
56
+ ):
57
+ super().__init__()
58
+
59
+ self.register_modules(
60
+ vae=vae,
61
+ image_encoder=image_encoder,
62
+ reference_unet=reference_unet,
63
+ denoising_unet=denoising_unet,
64
+ pose_guider=pose_guider,
65
+ scheduler=scheduler,
66
+ image_proj_model=image_proj_model,
67
+ tokenizer=tokenizer,
68
+ text_encoder=text_encoder,
69
+ )
70
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
71
+ self.clip_image_processor = CLIPImageProcessor()
72
+ self.ref_image_processor = VaeImageProcessor(
73
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
74
+ )
75
+ self.cond_image_processor = VaeImageProcessor(
76
+ vae_scale_factor=self.vae_scale_factor,
77
+ do_convert_rgb=True,
78
+ do_normalize=False,
79
+ )
80
+
81
+ def enable_vae_slicing(self):
82
+ self.vae.enable_slicing()
83
+
84
+ def disable_vae_slicing(self):
85
+ self.vae.disable_slicing()
86
+
87
+ def enable_sequential_cpu_offload(self, gpu_id=0):
88
+ if is_accelerate_available():
89
+ from accelerate import cpu_offload
90
+ else:
91
+ raise ImportError("Please install accelerate via `pip install accelerate`")
92
+
93
+ device = torch.device(f"cuda:{gpu_id}")
94
+
95
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
96
+ if cpu_offloaded_model is not None:
97
+ cpu_offload(cpu_offloaded_model, device)
98
+
99
+ @property
100
+ def _execution_device(self):
101
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
102
+ return self.device
103
+ for module in self.unet.modules():
104
+ if (
105
+ hasattr(module, "_hf_hook")
106
+ and hasattr(module._hf_hook, "execution_device")
107
+ and module._hf_hook.execution_device is not None
108
+ ):
109
+ return torch.device(module._hf_hook.execution_device)
110
+ return self.device
111
+
112
+ def decode_latents(self, latents):
113
+ video_length = latents.shape[2]
114
+ latents = 1 / 0.18215 * latents
115
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
116
+ # video = self.vae.decode(latents).sample
117
+ video = []
118
+ for frame_idx in tqdm(range(latents.shape[0])):
119
+ video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
120
+ video = torch.cat(video)
121
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
122
+ video = (video / 2 + 0.5).clamp(0, 1)
123
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
124
+ video = video.cpu().float().numpy()
125
+ return video
126
+
127
+ def prepare_extra_step_kwargs(self, generator, eta):
128
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
129
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
130
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
131
+ # and should be between [0, 1]
132
+
133
+ accepts_eta = "eta" in set(
134
+ inspect.signature(self.scheduler.step).parameters.keys()
135
+ )
136
+ extra_step_kwargs = {}
137
+ if accepts_eta:
138
+ extra_step_kwargs["eta"] = eta
139
+
140
+ # check if the scheduler accepts generator
141
+ accepts_generator = "generator" in set(
142
+ inspect.signature(self.scheduler.step).parameters.keys()
143
+ )
144
+ if accepts_generator:
145
+ extra_step_kwargs["generator"] = generator
146
+ return extra_step_kwargs
147
+
148
+ def prepare_latents(
149
+ self,
150
+ batch_size,
151
+ num_channels_latents,
152
+ width,
153
+ height,
154
+ video_length,
155
+ dtype,
156
+ device,
157
+ generator,
158
+ latents=None,
159
+ ):
160
+ shape = (
161
+ batch_size,
162
+ num_channels_latents,
163
+ video_length,
164
+ height // self.vae_scale_factor,
165
+ width // self.vae_scale_factor,
166
+ )
167
+ if isinstance(generator, list) and len(generator) != batch_size:
168
+ raise ValueError(
169
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
170
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
171
+ )
172
+
173
+ if latents is None:
174
+ latents = randn_tensor(
175
+ shape, generator=generator, device=device, dtype=dtype
176
+ )
177
+ else:
178
+ latents = latents.to(device)
179
+
180
+ # scale the initial noise by the standard deviation required by the scheduler
181
+ latents = latents * self.scheduler.init_noise_sigma
182
+ return latents
183
+
184
+ def _encode_prompt(
185
+ self,
186
+ prompt,
187
+ device,
188
+ num_videos_per_prompt,
189
+ do_classifier_free_guidance,
190
+ negative_prompt,
191
+ ):
192
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
193
+
194
+ text_inputs = self.tokenizer(
195
+ prompt,
196
+ padding="max_length",
197
+ max_length=self.tokenizer.model_max_length,
198
+ truncation=True,
199
+ return_tensors="pt",
200
+ )
201
+ text_input_ids = text_inputs.input_ids
202
+ untruncated_ids = self.tokenizer(
203
+ prompt, padding="longest", return_tensors="pt"
204
+ ).input_ids
205
+
206
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
207
+ text_input_ids, untruncated_ids
208
+ ):
209
+ removed_text = self.tokenizer.batch_decode(
210
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
211
+ )
212
+
213
+ if (
214
+ hasattr(self.text_encoder.config, "use_attention_mask")
215
+ and self.text_encoder.config.use_attention_mask
216
+ ):
217
+ attention_mask = text_inputs.attention_mask.to(device)
218
+ else:
219
+ attention_mask = None
220
+
221
+ text_embeddings = self.text_encoder(
222
+ text_input_ids.to(device),
223
+ attention_mask=attention_mask,
224
+ )
225
+ text_embeddings = text_embeddings[0]
226
+
227
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
228
+ bs_embed, seq_len, _ = text_embeddings.shape
229
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
230
+ text_embeddings = text_embeddings.view(
231
+ bs_embed * num_videos_per_prompt, seq_len, -1
232
+ )
233
+
234
+ # get unconditional embeddings for classifier free guidance
235
+ if do_classifier_free_guidance:
236
+ uncond_tokens: List[str]
237
+ if negative_prompt is None:
238
+ uncond_tokens = [""] * batch_size
239
+ elif type(prompt) is not type(negative_prompt):
240
+ raise TypeError(
241
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
242
+ f" {type(prompt)}."
243
+ )
244
+ elif isinstance(negative_prompt, str):
245
+ uncond_tokens = [negative_prompt]
246
+ elif batch_size != len(negative_prompt):
247
+ raise ValueError(
248
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
249
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
250
+ " the batch size of `prompt`."
251
+ )
252
+ else:
253
+ uncond_tokens = negative_prompt
254
+
255
+ max_length = text_input_ids.shape[-1]
256
+ uncond_input = self.tokenizer(
257
+ uncond_tokens,
258
+ padding="max_length",
259
+ max_length=max_length,
260
+ truncation=True,
261
+ return_tensors="pt",
262
+ )
263
+
264
+ if (
265
+ hasattr(self.text_encoder.config, "use_attention_mask")
266
+ and self.text_encoder.config.use_attention_mask
267
+ ):
268
+ attention_mask = uncond_input.attention_mask.to(device)
269
+ else:
270
+ attention_mask = None
271
+
272
+ uncond_embeddings = self.text_encoder(
273
+ uncond_input.input_ids.to(device),
274
+ attention_mask=attention_mask,
275
+ )
276
+ uncond_embeddings = uncond_embeddings[0]
277
+
278
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
279
+ seq_len = uncond_embeddings.shape[1]
280
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
281
+ uncond_embeddings = uncond_embeddings.view(
282
+ batch_size * num_videos_per_prompt, seq_len, -1
283
+ )
284
+
285
+ # For classifier free guidance, we need to do two forward passes.
286
+ # Here we concatenate the unconditional and text embeddings into a single batch
287
+ # to avoid doing two forward passes
288
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
289
+
290
+ return text_embeddings
291
+
292
+ def interpolate_latents(
293
+ self, latents: torch.Tensor, interpolation_factor: int, device
294
+ ):
295
+ if interpolation_factor < 2:
296
+ return latents
297
+
298
+ new_latents = torch.zeros(
299
+ (
300
+ latents.shape[0],
301
+ latents.shape[1],
302
+ ((latents.shape[2] - 1) * interpolation_factor) + 1,
303
+ latents.shape[3],
304
+ latents.shape[4],
305
+ ),
306
+ device=latents.device,
307
+ dtype=latents.dtype,
308
+ )
309
+
310
+ org_video_length = latents.shape[2]
311
+ rate = [i / interpolation_factor for i in range(interpolation_factor)][1:]
312
+
313
+ new_index = 0
314
+
315
+ v0 = None
316
+ v1 = None
317
+
318
+ for i0, i1 in zip(range(org_video_length), range(org_video_length)[1:]):
319
+ v0 = latents[:, :, i0, :, :]
320
+ v1 = latents[:, :, i1, :, :]
321
+
322
+ new_latents[:, :, new_index, :, :] = v0
323
+ new_index += 1
324
+
325
+ for f in rate:
326
+ v = get_tensor_interpolation_method()(
327
+ v0.to(device=device), v1.to(device=device), f
328
+ )
329
+ new_latents[:, :, new_index, :, :] = v.to(latents.device)
330
+ new_index += 1
331
+
332
+ new_latents[:, :, new_index, :, :] = v1
333
+ new_index += 1
334
+
335
+ return new_latents
336
+
337
+ @torch.no_grad()
338
+ def __call__(
339
+ self,
340
+ ref_image,
341
+ pose_images,
342
+ width,
343
+ height,
344
+ video_length,
345
+ num_inference_steps,
346
+ guidance_scale,
347
+ num_images_per_prompt=1,
348
+ eta: float = 0.0,
349
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
350
+ output_type: Optional[str] = "tensor",
351
+ return_dict: bool = True,
352
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
353
+ callback_steps: Optional[int] = 1,
354
+ context_schedule="uniform",
355
+ context_frames=24,
356
+ context_stride=1,
357
+ context_overlap=4,
358
+ context_batch_size=1,
359
+ interpolation_factor=1,
360
+ **kwargs,
361
+ ):
362
+ # Default height and width to unet
363
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
364
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
365
+
366
+ device = self._execution_device
367
+
368
+ do_classifier_free_guidance = guidance_scale > 1.0
369
+
370
+ # Prepare timesteps
371
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
372
+ timesteps = self.scheduler.timesteps
373
+
374
+ batch_size = 1
375
+
376
+ # Prepare clip image embeds
377
+ clip_image = self.clip_image_processor.preprocess(
378
+ ref_image.resize((224, 224)), return_tensors="pt"
379
+ ).pixel_values
380
+ clip_image_embeds = self.image_encoder(
381
+ clip_image.to(device, dtype=self.image_encoder.dtype)
382
+ ).image_embeds
383
+ encoder_hidden_states = clip_image_embeds.unsqueeze(1)
384
+ uncond_encoder_hidden_states = torch.zeros_like(encoder_hidden_states)
385
+
386
+ if do_classifier_free_guidance:
387
+ encoder_hidden_states = torch.cat(
388
+ [uncond_encoder_hidden_states, encoder_hidden_states], dim=0
389
+ )
390
+
391
+ reference_control_writer = ReferenceAttentionControl(
392
+ self.reference_unet,
393
+ do_classifier_free_guidance=do_classifier_free_guidance,
394
+ mode="write",
395
+ batch_size=batch_size,
396
+ fusion_blocks="full",
397
+ )
398
+ reference_control_reader = ReferenceAttentionControl(
399
+ self.denoising_unet,
400
+ do_classifier_free_guidance=do_classifier_free_guidance,
401
+ mode="read",
402
+ batch_size=batch_size,
403
+ fusion_blocks="full",
404
+ )
405
+
406
+ num_channels_latents = self.denoising_unet.in_channels
407
+ latents = self.prepare_latents(
408
+ batch_size * num_images_per_prompt,
409
+ num_channels_latents,
410
+ width,
411
+ height,
412
+ video_length,
413
+ clip_image_embeds.dtype,
414
+ device,
415
+ generator,
416
+ )
417
+
418
+ # Prepare extra step kwargs.
419
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
420
+
421
+ # Prepare ref image latents
422
+ ref_image_tensor = self.ref_image_processor.preprocess(
423
+ ref_image, height=height, width=width
424
+ ) # (bs, c, h, w)
425
+ ref_image_tensor = ref_image_tensor.to(
426
+ dtype=self.vae.dtype, device=self.vae.device
427
+ )
428
+ ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
429
+ ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
430
+
431
+ # Prepare a list of pose condition images
432
+ pose_cond_tensor_list = []
433
+ for pose_image in pose_images:
434
+ pose_cond_tensor = self.cond_image_processor.preprocess(
435
+ pose_image, height=height, width=width
436
+ )
437
+ pose_cond_tensor = pose_cond_tensor.unsqueeze(2) # (bs, c, 1, h, w)
438
+ pose_cond_tensor_list.append(pose_cond_tensor)
439
+ pose_cond_tensor = torch.cat(pose_cond_tensor_list, dim=2) # (bs, c, t, h, w)
440
+ pose_cond_tensor = pose_cond_tensor.to(
441
+ device=device, dtype=self.pose_guider.dtype
442
+ )
443
+ pose_fea = self.pose_guider(pose_cond_tensor)
444
+
445
+ context_scheduler = get_context_scheduler(context_schedule)
446
+
447
+ # denoising loop
448
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
449
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
450
+ for i, t in enumerate(timesteps):
451
+ noise_pred = torch.zeros(
452
+ (
453
+ latents.shape[0] * (2 if do_classifier_free_guidance else 1),
454
+ *latents.shape[1:],
455
+ ),
456
+ device=latents.device,
457
+ dtype=latents.dtype,
458
+ )
459
+ counter = torch.zeros(
460
+ (1, 1, latents.shape[2], 1, 1),
461
+ device=latents.device,
462
+ dtype=latents.dtype,
463
+ )
464
+
465
+ # 1. Forward reference image
466
+ if i == 0:
467
+ self.reference_unet(
468
+ ref_image_latents.repeat(
469
+ (2 if do_classifier_free_guidance else 1), 1, 1, 1
470
+ ),
471
+ torch.zeros_like(t),
472
+ # t,
473
+ encoder_hidden_states=encoder_hidden_states,
474
+ return_dict=False,
475
+ )
476
+ reference_control_reader.update(reference_control_writer)
477
+
478
+ context_queue = list(
479
+ context_scheduler(
480
+ 0,
481
+ num_inference_steps,
482
+ latents.shape[2],
483
+ context_frames,
484
+ context_stride,
485
+ 0,
486
+ )
487
+ )
488
+ num_context_batches = math.ceil(len(context_queue) / context_batch_size)
489
+
490
+ context_queue = list(
491
+ context_scheduler(
492
+ 0,
493
+ num_inference_steps,
494
+ latents.shape[2],
495
+ context_frames,
496
+ context_stride,
497
+ context_overlap,
498
+ )
499
+ )
500
+
501
+ num_context_batches = math.ceil(len(context_queue) / context_batch_size)
502
+ global_context = []
503
+ for i in range(num_context_batches):
504
+ global_context.append(
505
+ context_queue[
506
+ i * context_batch_size : (i + 1) * context_batch_size
507
+ ]
508
+ )
509
+
510
+ for context in global_context:
511
+ # 3.1 expand the latents if we are doing classifier free guidance
512
+ latent_model_input = (
513
+ torch.cat([latents[:, :, c] for c in context])
514
+ .to(device)
515
+ .repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1)
516
+ )
517
+ latent_model_input = self.scheduler.scale_model_input(
518
+ latent_model_input, t
519
+ )
520
+ b, c, f, h, w = latent_model_input.shape
521
+ latent_pose_input = torch.cat(
522
+ [pose_fea[:, :, c] for c in context]
523
+ ).repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1)
524
+
525
+ pred = self.denoising_unet(
526
+ latent_model_input,
527
+ t,
528
+ encoder_hidden_states=encoder_hidden_states[:b],
529
+ pose_cond_fea=latent_pose_input,
530
+ return_dict=False,
531
+ )[0]
532
+
533
+ for j, c in enumerate(context):
534
+ noise_pred[:, :, c] = noise_pred[:, :, c] + pred
535
+ counter[:, :, c] = counter[:, :, c] + 1
536
+
537
+ # perform guidance
538
+ if do_classifier_free_guidance:
539
+ noise_pred_uncond, noise_pred_text = (noise_pred / counter).chunk(2)
540
+ noise_pred = noise_pred_uncond + guidance_scale * (
541
+ noise_pred_text - noise_pred_uncond
542
+ )
543
+
544
+ latents = self.scheduler.step(
545
+ noise_pred, t, latents, **extra_step_kwargs
546
+ ).prev_sample
547
+
548
+ if i == len(timesteps) - 1 or (
549
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
550
+ ):
551
+ progress_bar.update()
552
+ if callback is not None and i % callback_steps == 0:
553
+ step_idx = i // getattr(self.scheduler, "order", 1)
554
+ callback(step_idx, t, latents)
555
+
556
+ reference_control_reader.clear()
557
+ reference_control_writer.clear()
558
+
559
+ if interpolation_factor > 0:
560
+ latents = self.interpolate_latents(latents, interpolation_factor, device)
561
+ # Post-processing
562
+ images = self.decode_latents(latents) # (b, c, f, h, w)
563
+
564
+ # Convert to tensor
565
+ if output_type == "tensor":
566
+ images = torch.from_numpy(images)
567
+
568
+ if not return_dict:
569
+ return images
570
+
571
+ return Pose2VideoPipelineOutput(videos=images)
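The denoising loop above processes the video in overlapping context windows: each window's noise prediction is summed into a shared buffer alongside a per-frame counter, and the buffer is divided by the counter before classifier-free guidance. A minimal standalone sketch of that accumulation, with illustrative window sizes rather than the pipeline's defaults:

import torch

# toy setup: 8 frames, 4-frame windows overlapping by 2 frames (illustrative values)
frames, window, overlap = 8, 4, 2
latents = torch.randn(1, 4, frames, 8, 8)        # (b, c, f, h, w)

noise_pred = torch.zeros_like(latents)           # running sum of per-window predictions
counter = torch.zeros(1, 1, frames, 1, 1)        # how many windows covered each frame

for start in range(0, frames - overlap, window - overlap):
    idx = [i % frames for i in range(start, start + window)]
    pred = torch.randn_like(latents[:, :, idx])  # stand-in for the denoising UNet output
    noise_pred[:, :, idx] += pred
    counter[:, :, idx] += 1

averaged = noise_pred / counter                  # per-frame average, as in the loop above
print(counter.flatten())                         # frames covered by two windows show a count of 2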
musepose/pipelines/utils.py ADDED
@@ -0,0 +1,29 @@
1
+ import torch
2
+
3
+ tensor_interpolation = None
4
+
5
+
6
+ def get_tensor_interpolation_method():
7
+ return tensor_interpolation
8
+
9
+
10
+ def set_tensor_interpolation_method(is_slerp):
11
+ global tensor_interpolation
12
+ tensor_interpolation = slerp if is_slerp else linear
13
+
14
+
15
+ def linear(v1, v2, t):
16
+ return (1.0 - t) * v1 + t * v2
17
+
18
+
19
+ def slerp(
20
+ v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995
21
+ ) -> torch.Tensor:
22
+ u0 = v0 / v0.norm()
23
+ u1 = v1 / v1.norm()
24
+ dot = (u0 * u1).sum()
25
+ if dot.abs() > DOT_THRESHOLD:
26
+ # logger.info(f'warning: v0 and v1 close to parallel, using linear interpolation instead.')
27
+ return (1.0 - t) * v0 + t * v1
28
+ omega = dot.acos()
29
+ return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin()
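The module above registers either linear or spherical interpolation for the frame-interpolation step; `get_tensor_interpolation_method` returns whichever was set. A short usage sketch, assuming the `musepose` package is importable from the repository root:

import torch
from musepose.pipelines.utils import (
    get_tensor_interpolation_method,
    set_tensor_interpolation_method,
)

set_tensor_interpolation_method(is_slerp=True)   # register slerp; False registers linear
interp = get_tensor_interpolation_method()

v0, v1 = torch.randn(4, 64, 64), torch.randn(4, 64, 64)
mid = interp(v0, v1, 0.5)                        # interpolated tensor halfway between v0 and v1
print(mid.shape)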
musepose/utils/util.py ADDED
@@ -0,0 +1,133 @@
1
+ import importlib
2
+ import os
3
+ import os.path as osp
4
+ import shutil
5
+ import sys
6
+ from pathlib import Path
7
+
8
+ import av
9
+ import numpy as np
10
+ import torch
11
+ import torchvision
12
+ from einops import rearrange
13
+ from PIL import Image
14
+
15
+
16
+ def seed_everything(seed):
17
+ import random
18
+
19
+ import numpy as np
20
+
21
+ torch.manual_seed(seed)
22
+ torch.cuda.manual_seed_all(seed)
23
+ np.random.seed(seed % (2**32))
24
+ random.seed(seed)
25
+
26
+
27
+ def import_filename(filename):
28
+ spec = importlib.util.spec_from_file_location("mymodule", filename)
29
+ module = importlib.util.module_from_spec(spec)
30
+ sys.modules[spec.name] = module
31
+ spec.loader.exec_module(module)
32
+ return module
33
+
34
+
35
+ def delete_additional_ckpt(base_path, num_keep):
36
+ dirs = []
37
+ for d in os.listdir(base_path):
38
+ if d.startswith("checkpoint-"):
39
+ dirs.append(d)
40
+ num_tot = len(dirs)
41
+ if num_tot <= num_keep:
42
+ return
43
+ # ensure checkpoints are sorted and delete the earlier ones
44
+ del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep]
45
+ for d in del_dirs:
46
+ path_to_dir = osp.join(base_path, d)
47
+ if osp.exists(path_to_dir):
48
+ shutil.rmtree(path_to_dir)
49
+
50
+
51
+ def save_videos_from_pil(pil_images, path, fps=8):
52
+ import av
53
+
54
+ save_fmt = Path(path).suffix
55
+ os.makedirs(os.path.dirname(path), exist_ok=True)
56
+ width, height = pil_images[0].size
57
+
58
+ if save_fmt == ".mp4":
59
+ codec = "libx264"
60
+ container = av.open(path, "w")
61
+ stream = container.add_stream(codec, rate=fps)
62
+
63
+ stream.width = width
64
+ stream.height = height
65
+ stream.pix_fmt = 'yuv420p'
66
+ stream.bit_rate = 10000000
67
+ stream.options["crf"] = "18"
68
+
69
+
70
+
71
+ for pil_image in pil_images:
72
+ # pil_image = Image.fromarray(image_arr).convert("RGB")
73
+ av_frame = av.VideoFrame.from_image(pil_image)
74
+ container.mux(stream.encode(av_frame))
75
+ container.mux(stream.encode())
76
+ container.close()
77
+
78
+ elif save_fmt == ".gif":
79
+ pil_images[0].save(
80
+ fp=path,
81
+ format="GIF",
82
+ append_images=pil_images[1:],
83
+ save_all=True,
84
+ duration=(1 / fps * 1000),
85
+ loop=0,
86
+ )
87
+ else:
88
+ raise ValueError("Unsupported file type. Use .mp4 or .gif.")
89
+
90
+
91
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
92
+ videos = rearrange(videos, "b c t h w -> t b c h w")
93
+ height, width = videos.shape[-2:]
94
+ outputs = []
95
+
96
+ for x in videos:
97
+ x = torchvision.utils.make_grid(x, nrow=n_rows) # (c h w)
98
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c)
99
+ if rescale:
100
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
101
+ x = (x * 255).numpy().astype(np.uint8)
102
+ x = Image.fromarray(x)
103
+
104
+ outputs.append(x)
105
+
106
+ os.makedirs(os.path.dirname(path), exist_ok=True)
107
+
108
+ save_videos_from_pil(outputs, path, fps)
109
+
110
+
111
+ def read_frames(video_path):
112
+ container = av.open(video_path)
113
+
114
+ video_stream = next(s for s in container.streams if s.type == "video")
115
+ frames = []
116
+ for packet in container.demux(video_stream):
117
+ for frame in packet.decode():
118
+ image = Image.frombytes(
119
+ "RGB",
120
+ (frame.width, frame.height),
121
+ frame.to_rgb().to_ndarray(),
122
+ )
123
+ frames.append(image)
124
+
125
+ return frames
126
+
127
+
128
+ def get_fps(video_path):
129
+ container = av.open(video_path)
130
+ video_stream = next(s for s in container.streams if s.type == "video")
131
+ fps = video_stream.average_rate
132
+ container.close()
133
+ return fps
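The utilities above decode frames with PyAV and re-encode them as an H.264 mp4 or a GIF. A usage sketch that round-trips a clip at its original frame rate; the paths here are placeholders, not files shipped in this repository:

from musepose.utils.util import get_fps, read_frames, save_videos_from_pil

src = "assets/example_dance.mp4"                   # placeholder input path
frames = read_frames(src)                          # list of PIL.Image frames
fps = get_fps(src)                                 # average frame rate reported by the container
save_videos_from_pil(frames, "output/example_copy.mp4", fps=int(fps))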
pipelines/__init__.py ADDED
File without changes
pipelines/context.py ADDED
@@ -0,0 +1,76 @@
1
+ # TODO: Adapted from cli
2
+ from typing import Callable, List, Optional
3
+
4
+ import numpy as np
5
+
6
+
7
+ def ordered_halving(val):
8
+ bin_str = f"{val:064b}"
9
+ bin_flip = bin_str[::-1]
10
+ as_int = int(bin_flip, 2)
11
+
12
+ return as_int / (1 << 64)
13
+
14
+
15
+ def uniform(
16
+ step: int = ...,
17
+ num_steps: Optional[int] = None,
18
+ num_frames: int = ...,
19
+ context_size: Optional[int] = None,
20
+ context_stride: int = 3,
21
+ context_overlap: int = 4,
22
+ closed_loop: bool = False,
23
+ ):
24
+ if num_frames <= context_size:
25
+ yield list(range(num_frames))
26
+ return
27
+
28
+ context_stride = min(
29
+ context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1
30
+ )
31
+
32
+ for context_step in 1 << np.arange(context_stride):
33
+ pad = int(round(num_frames * ordered_halving(step)))
34
+ for j in range(
35
+ int(ordered_halving(step) * context_step) + pad,
36
+ num_frames + pad + (0 if closed_loop else -context_overlap),
37
+ (context_size * context_step - context_overlap),
38
+ ):
39
+ yield [
40
+ e % num_frames
41
+ for e in range(j, j + context_size * context_step, context_step)
42
+ ]
43
+
44
+
45
+ def get_context_scheduler(name: str) -> Callable:
46
+ if name == "uniform":
47
+ return uniform
48
+ else:
49
+ raise ValueError(f"Unknown context_overlap policy {name}")
50
+
51
+
52
+ def get_total_steps(
53
+ scheduler,
54
+ timesteps: List[int],
55
+ num_steps: Optional[int] = None,
56
+ num_frames: int = ...,
57
+ context_size: Optional[int] = None,
58
+ context_stride: int = 3,
59
+ context_overlap: int = 4,
60
+ closed_loop: bool = True,
61
+ ):
62
+ return sum(
63
+ len(
64
+ list(
65
+ scheduler(
66
+ i,
67
+ num_steps,
68
+ num_frames,
69
+ context_size,
70
+ context_stride,
71
+ context_overlap,
72
+ )
73
+ )
74
+ )
75
+ for i in range(len(timesteps))
76
+ )
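`uniform` yields overlapping windows of frame indices for each denoising step, and the long-video pipeline runs the UNet once per window. A quick sketch of the windows it produces; the frame count and window settings are illustrative:

from pipelines.context import get_context_scheduler

scheduler = get_context_scheduler("uniform")
# step 0, 20 denoising steps, 72 frames, 24-frame windows, stride 1, 4-frame overlap
windows = list(scheduler(0, 20, 72, 24, 1, 4))
for w in windows:
    print(w[0], "...", w[-1])   # consecutive windows share 4 frames; indices wrap modulo 72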
pipelines/pipeline_pose2img.py ADDED
@@ -0,0 +1,360 @@
1
+ import inspect
2
+ from dataclasses import dataclass
3
+ from typing import Callable, List, Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from diffusers import DiffusionPipeline
8
+ from diffusers.image_processor import VaeImageProcessor
9
+ from diffusers.schedulers import (
10
+ DDIMScheduler,
11
+ DPMSolverMultistepScheduler,
12
+ EulerAncestralDiscreteScheduler,
13
+ EulerDiscreteScheduler,
14
+ LMSDiscreteScheduler,
15
+ PNDMScheduler,
16
+ )
17
+ from diffusers.utils import BaseOutput, is_accelerate_available
18
+ from diffusers.utils.torch_utils import randn_tensor
19
+ from einops import rearrange
20
+ from tqdm import tqdm
21
+ from transformers import CLIPImageProcessor
22
+
23
+ from musepose.models.mutual_self_attention import ReferenceAttentionControl
24
+
25
+
26
+ @dataclass
27
+ class Pose2ImagePipelineOutput(BaseOutput):
28
+ images: Union[torch.Tensor, np.ndarray]
29
+
30
+
31
+ class Pose2ImagePipeline(DiffusionPipeline):
32
+ _optional_components = []
33
+
34
+ def __init__(
35
+ self,
36
+ vae,
37
+ image_encoder,
38
+ reference_unet,
39
+ denoising_unet,
40
+ pose_guider,
41
+ scheduler: Union[
42
+ DDIMScheduler,
43
+ PNDMScheduler,
44
+ LMSDiscreteScheduler,
45
+ EulerDiscreteScheduler,
46
+ EulerAncestralDiscreteScheduler,
47
+ DPMSolverMultistepScheduler,
48
+ ],
49
+ ):
50
+ super().__init__()
51
+
52
+ self.register_modules(
53
+ vae=vae,
54
+ image_encoder=image_encoder,
55
+ reference_unet=reference_unet,
56
+ denoising_unet=denoising_unet,
57
+ pose_guider=pose_guider,
58
+ scheduler=scheduler,
59
+ )
60
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
61
+ self.clip_image_processor = CLIPImageProcessor()
62
+ self.ref_image_processor = VaeImageProcessor(
63
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
64
+ )
65
+ self.cond_image_processor = VaeImageProcessor(
66
+ vae_scale_factor=self.vae_scale_factor,
67
+ do_convert_rgb=True,
68
+ do_normalize=False,
69
+ )
70
+
71
+ def enable_vae_slicing(self):
72
+ self.vae.enable_slicing()
73
+
74
+ def disable_vae_slicing(self):
75
+ self.vae.disable_slicing()
76
+
77
+ def enable_sequential_cpu_offload(self, gpu_id=0):
78
+ if is_accelerate_available():
79
+ from accelerate import cpu_offload
80
+ else:
81
+ raise ImportError("Please install accelerate via `pip install accelerate`")
82
+
83
+ device = torch.device(f"cuda:{gpu_id}")
84
+
85
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
86
+ if cpu_offloaded_model is not None:
87
+ cpu_offload(cpu_offloaded_model, device)
88
+
89
+ @property
90
+ def _execution_device(self):
91
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
92
+ return self.device
93
+ for module in self.unet.modules():
94
+ if (
95
+ hasattr(module, "_hf_hook")
96
+ and hasattr(module._hf_hook, "execution_device")
97
+ and module._hf_hook.execution_device is not None
98
+ ):
99
+ return torch.device(module._hf_hook.execution_device)
100
+ return self.device
101
+
102
+ def decode_latents(self, latents):
103
+ video_length = latents.shape[2]
104
+ latents = 1 / 0.18215 * latents
105
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
106
+ # video = self.vae.decode(latents).sample
107
+ video = []
108
+ for frame_idx in tqdm(range(latents.shape[0])):
109
+ video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
110
+ video = torch.cat(video)
111
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
112
+ video = (video / 2 + 0.5).clamp(0, 1)
113
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
114
+ video = video.cpu().float().numpy()
115
+ return video
116
+
117
+ def prepare_extra_step_kwargs(self, generator, eta):
118
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
119
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
120
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
121
+ # and should be between [0, 1]
122
+
123
+ accepts_eta = "eta" in set(
124
+ inspect.signature(self.scheduler.step).parameters.keys()
125
+ )
126
+ extra_step_kwargs = {}
127
+ if accepts_eta:
128
+ extra_step_kwargs["eta"] = eta
129
+
130
+ # check if the scheduler accepts generator
131
+ accepts_generator = "generator" in set(
132
+ inspect.signature(self.scheduler.step).parameters.keys()
133
+ )
134
+ if accepts_generator:
135
+ extra_step_kwargs["generator"] = generator
136
+ return extra_step_kwargs
137
+
138
+ def prepare_latents(
139
+ self,
140
+ batch_size,
141
+ num_channels_latents,
142
+ width,
143
+ height,
144
+ dtype,
145
+ device,
146
+ generator,
147
+ latents=None,
148
+ ):
149
+ shape = (
150
+ batch_size,
151
+ num_channels_latents,
152
+ height // self.vae_scale_factor,
153
+ width // self.vae_scale_factor,
154
+ )
155
+ if isinstance(generator, list) and len(generator) != batch_size:
156
+ raise ValueError(
157
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
158
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
159
+ )
160
+
161
+ if latents is None:
162
+ latents = randn_tensor(
163
+ shape, generator=generator, device=device, dtype=dtype
164
+ )
165
+ else:
166
+ latents = latents.to(device)
167
+
168
+ # scale the initial noise by the standard deviation required by the scheduler
169
+ latents = latents * self.scheduler.init_noise_sigma
170
+ return latents
171
+
172
+ def prepare_condition(
173
+ self,
174
+ cond_image,
175
+ width,
176
+ height,
177
+ device,
178
+ dtype,
179
+ do_classififer_free_guidance=False,
180
+ ):
181
+ image = self.cond_image_processor.preprocess(
182
+ cond_image, height=height, width=width
183
+ ).to(dtype=torch.float32)
184
+
185
+ image = image.to(device=device, dtype=dtype)
186
+
187
+ if do_classififer_free_guidance:
188
+ image = torch.cat([image] * 2)
189
+
190
+ return image
191
+
192
+ @torch.no_grad()
193
+ def __call__(
194
+ self,
195
+ ref_image,
196
+ pose_image,
197
+ width,
198
+ height,
199
+ num_inference_steps,
200
+ guidance_scale,
201
+ num_images_per_prompt=1,
202
+ eta: float = 0.0,
203
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
204
+ output_type: Optional[str] = "tensor",
205
+ return_dict: bool = True,
206
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
207
+ callback_steps: Optional[int] = 1,
208
+ **kwargs,
209
+ ):
210
+ # Default height and width to unet
211
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
212
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
213
+
214
+ device = self._execution_device
215
+
216
+ do_classifier_free_guidance = guidance_scale > 1.0
217
+
218
+ # Prepare timesteps
219
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
220
+ timesteps = self.scheduler.timesteps
221
+
222
+ batch_size = 1
223
+
224
+ # Prepare clip image embeds
225
+ clip_image = self.clip_image_processor.preprocess(
226
+ ref_image.resize((224, 224)), return_tensors="pt"
227
+ ).pixel_values
228
+ clip_image_embeds = self.image_encoder(
229
+ clip_image.to(device, dtype=self.image_encoder.dtype)
230
+ ).image_embeds
231
+ image_prompt_embeds = clip_image_embeds.unsqueeze(1)
232
+ uncond_image_prompt_embeds = torch.zeros_like(image_prompt_embeds)
233
+
234
+ if do_classifier_free_guidance:
235
+ image_prompt_embeds = torch.cat(
236
+ [uncond_image_prompt_embeds, image_prompt_embeds], dim=0
237
+ )
238
+
239
+ reference_control_writer = ReferenceAttentionControl(
240
+ self.reference_unet,
241
+ do_classifier_free_guidance=do_classifier_free_guidance,
242
+ mode="write",
243
+ batch_size=batch_size,
244
+ fusion_blocks="full",
245
+ )
246
+ reference_control_reader = ReferenceAttentionControl(
247
+ self.denoising_unet,
248
+ do_classifier_free_guidance=do_classifier_free_guidance,
249
+ mode="read",
250
+ batch_size=batch_size,
251
+ fusion_blocks="full",
252
+ )
253
+
254
+ num_channels_latents = self.denoising_unet.in_channels
255
+ latents = self.prepare_latents(
256
+ batch_size * num_images_per_prompt,
257
+ num_channels_latents,
258
+ width,
259
+ height,
260
+ clip_image_embeds.dtype,
261
+ device,
262
+ generator,
263
+ )
264
+ latents = latents.unsqueeze(2) # (bs, c, 1, h', w')
265
+ latents_dtype = latents.dtype
266
+
267
+ # Prepare extra step kwargs.
268
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
269
+
270
+ # Prepare ref image latents
271
+ ref_image_tensor = self.ref_image_processor.preprocess(
272
+ ref_image, height=height, width=width
273
+ ) # (bs, c, h, w)
274
+ ref_image_tensor = ref_image_tensor.to(
275
+ dtype=self.vae.dtype, device=self.vae.device
276
+ )
277
+ ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
278
+ ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
279
+
280
+ # Prepare pose condition image
281
+ pose_cond_tensor = self.cond_image_processor.preprocess(
282
+ pose_image, height=height, width=width
283
+ )
284
+ pose_cond_tensor = pose_cond_tensor.unsqueeze(2) # (bs, c, 1, h, w)
285
+ pose_cond_tensor = pose_cond_tensor.to(
286
+ device=device, dtype=self.pose_guider.dtype
287
+ )
288
+ pose_fea = self.pose_guider(pose_cond_tensor)
289
+ pose_fea = (
290
+ torch.cat([pose_fea] * 2) if do_classifier_free_guidance else pose_fea
291
+ )
292
+
293
+ # denoising loop
294
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
295
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
296
+ for i, t in enumerate(timesteps):
297
+ # 1. Forward reference image
298
+ if i == 0:
299
+ self.reference_unet(
300
+ ref_image_latents.repeat(
301
+ (2 if do_classifier_free_guidance else 1), 1, 1, 1
302
+ ),
303
+ torch.zeros_like(t),
304
+ encoder_hidden_states=image_prompt_embeds,
305
+ return_dict=False,
306
+ )
307
+
308
+ # 2. Update reference unet features into the denoising net
309
+ reference_control_reader.update(reference_control_writer)
310
+
311
+ # 3.1 expand the latents if we are doing classifier free guidance
312
+ latent_model_input = (
313
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
314
+ )
315
+ latent_model_input = self.scheduler.scale_model_input(
316
+ latent_model_input, t
317
+ )
318
+
319
+ noise_pred = self.denoising_unet(
320
+ latent_model_input,
321
+ t,
322
+ encoder_hidden_states=image_prompt_embeds,
323
+ pose_cond_fea=pose_fea,
324
+ return_dict=False,
325
+ )[0]
326
+
327
+ # perform guidance
328
+ if do_classifier_free_guidance:
329
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
330
+ noise_pred = noise_pred_uncond + guidance_scale * (
331
+ noise_pred_text - noise_pred_uncond
332
+ )
333
+
334
+ # compute the previous noisy sample x_t -> x_t-1
335
+ latents = self.scheduler.step(
336
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
337
+ )[0]
338
+
339
+ # call the callback, if provided
340
+ if i == len(timesteps) - 1 or (
341
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
342
+ ):
343
+ progress_bar.update()
344
+ if callback is not None and i % callback_steps == 0:
345
+ step_idx = i // getattr(self.scheduler, "order", 1)
346
+ callback(step_idx, t, latents)
347
+ reference_control_reader.clear()
348
+ reference_control_writer.clear()
349
+
350
+ # Post-processing
351
+ image = self.decode_latents(latents) # (b, c, 1, h, w)
352
+
353
+ # Convert to tensor
354
+ if output_type == "tensor":
355
+ image = torch.from_numpy(image)
356
+
357
+ if not return_dict:
358
+ return image
359
+
360
+ return Pose2ImagePipelineOutput(images=image)
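In the loop above, classifier-free guidance runs the unconditional and conditional halves as one doubled batch and then recombines the two noise predictions. A standalone sketch of that recombination; the guidance scale is an illustrative value:

import torch

guidance_scale = 3.5                            # illustrative value
noise_pred = torch.randn(2, 4, 1, 64, 64)       # [uncond, cond] stacked along the batch dim

noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(guided.shape)                             # (1, 4, 1, 64, 64), matching the latent batch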
pipelines/pipeline_pose2vid.py ADDED
@@ -0,0 +1,458 @@
1
+ import inspect
2
+ from dataclasses import dataclass
3
+ from typing import Callable, List, Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from diffusers import DiffusionPipeline
8
+ from diffusers.image_processor import VaeImageProcessor
9
+ from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler,
10
+ EulerAncestralDiscreteScheduler,
11
+ EulerDiscreteScheduler, LMSDiscreteScheduler,
12
+ PNDMScheduler)
13
+ from diffusers.utils import BaseOutput, is_accelerate_available
14
+ from diffusers.utils.torch_utils import randn_tensor
15
+ from einops import rearrange
16
+ from tqdm import tqdm
17
+ from transformers import CLIPImageProcessor
18
+
19
+ from musepose.models.mutual_self_attention import ReferenceAttentionControl
20
+
21
+
22
+ @dataclass
23
+ class Pose2VideoPipelineOutput(BaseOutput):
24
+ videos: Union[torch.Tensor, np.ndarray]
25
+
26
+
27
+ class Pose2VideoPipeline(DiffusionPipeline):
28
+ _optional_components = []
29
+
30
+ def __init__(
31
+ self,
32
+ vae,
33
+ image_encoder,
34
+ reference_unet,
35
+ denoising_unet,
36
+ pose_guider,
37
+ scheduler: Union[
38
+ DDIMScheduler,
39
+ PNDMScheduler,
40
+ LMSDiscreteScheduler,
41
+ EulerDiscreteScheduler,
42
+ EulerAncestralDiscreteScheduler,
43
+ DPMSolverMultistepScheduler,
44
+ ],
45
+ image_proj_model=None,
46
+ tokenizer=None,
47
+ text_encoder=None,
48
+ ):
49
+ super().__init__()
50
+
51
+ self.register_modules(
52
+ vae=vae,
53
+ image_encoder=image_encoder,
54
+ reference_unet=reference_unet,
55
+ denoising_unet=denoising_unet,
56
+ pose_guider=pose_guider,
57
+ scheduler=scheduler,
58
+ image_proj_model=image_proj_model,
59
+ tokenizer=tokenizer,
60
+ text_encoder=text_encoder,
61
+ )
62
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
63
+ self.clip_image_processor = CLIPImageProcessor()
64
+ self.ref_image_processor = VaeImageProcessor(
65
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
66
+ )
67
+ self.cond_image_processor = VaeImageProcessor(
68
+ vae_scale_factor=self.vae_scale_factor,
69
+ do_convert_rgb=True,
70
+ do_normalize=False,
71
+ )
72
+
73
+ def enable_vae_slicing(self):
74
+ self.vae.enable_slicing()
75
+
76
+ def disable_vae_slicing(self):
77
+ self.vae.disable_slicing()
78
+
79
+ def enable_sequential_cpu_offload(self, gpu_id=0):
80
+ if is_accelerate_available():
81
+ from accelerate import cpu_offload
82
+ else:
83
+ raise ImportError("Please install accelerate via `pip install accelerate`")
84
+
85
+ device = torch.device(f"cuda:{gpu_id}")
86
+
87
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
88
+ if cpu_offloaded_model is not None:
89
+ cpu_offload(cpu_offloaded_model, device)
90
+
91
+ @property
92
+ def _execution_device(self):
93
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
94
+ return self.device
95
+ for module in self.unet.modules():
96
+ if (
97
+ hasattr(module, "_hf_hook")
98
+ and hasattr(module._hf_hook, "execution_device")
99
+ and module._hf_hook.execution_device is not None
100
+ ):
101
+ return torch.device(module._hf_hook.execution_device)
102
+ return self.device
103
+
104
+ def decode_latents(self, latents):
105
+ video_length = latents.shape[2]
106
+ latents = 1 / 0.18215 * latents
107
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
108
+ # video = self.vae.decode(latents).sample
109
+ video = []
110
+ for frame_idx in tqdm(range(latents.shape[0])):
111
+ video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
112
+ video = torch.cat(video)
113
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
114
+ video = (video / 2 + 0.5).clamp(0, 1)
115
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
116
+ video = video.cpu().float().numpy()
117
+ return video
118
+
119
+ def prepare_extra_step_kwargs(self, generator, eta):
120
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
121
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
122
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
123
+ # and should be between [0, 1]
124
+
125
+ accepts_eta = "eta" in set(
126
+ inspect.signature(self.scheduler.step).parameters.keys()
127
+ )
128
+ extra_step_kwargs = {}
129
+ if accepts_eta:
130
+ extra_step_kwargs["eta"] = eta
131
+
132
+ # check if the scheduler accepts generator
133
+ accepts_generator = "generator" in set(
134
+ inspect.signature(self.scheduler.step).parameters.keys()
135
+ )
136
+ if accepts_generator:
137
+ extra_step_kwargs["generator"] = generator
138
+ return extra_step_kwargs
139
+
140
+ def prepare_latents(
141
+ self,
142
+ batch_size,
143
+ num_channels_latents,
144
+ width,
145
+ height,
146
+ video_length,
147
+ dtype,
148
+ device,
149
+ generator,
150
+ latents=None,
151
+ ):
152
+ shape = (
153
+ batch_size,
154
+ num_channels_latents,
155
+ video_length,
156
+ height // self.vae_scale_factor,
157
+ width // self.vae_scale_factor,
158
+ )
159
+ if isinstance(generator, list) and len(generator) != batch_size:
160
+ raise ValueError(
161
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
162
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
163
+ )
164
+
165
+ if latents is None:
166
+ latents = randn_tensor(
167
+ shape, generator=generator, device=device, dtype=dtype
168
+ )
169
+ else:
170
+ latents = latents.to(device)
171
+
172
+ # scale the initial noise by the standard deviation required by the scheduler
173
+ latents = latents * self.scheduler.init_noise_sigma
174
+ return latents
175
+
176
+ def _encode_prompt(
177
+ self,
178
+ prompt,
179
+ device,
180
+ num_videos_per_prompt,
181
+ do_classifier_free_guidance,
182
+ negative_prompt,
183
+ ):
184
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
185
+
186
+ text_inputs = self.tokenizer(
187
+ prompt,
188
+ padding="max_length",
189
+ max_length=self.tokenizer.model_max_length,
190
+ truncation=True,
191
+ return_tensors="pt",
192
+ )
193
+ text_input_ids = text_inputs.input_ids
194
+ untruncated_ids = self.tokenizer(
195
+ prompt, padding="longest", return_tensors="pt"
196
+ ).input_ids
197
+
198
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
199
+ text_input_ids, untruncated_ids
200
+ ):
201
+ removed_text = self.tokenizer.batch_decode(
202
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
203
+ )
204
+
205
+ if (
206
+ hasattr(self.text_encoder.config, "use_attention_mask")
207
+ and self.text_encoder.config.use_attention_mask
208
+ ):
209
+ attention_mask = text_inputs.attention_mask.to(device)
210
+ else:
211
+ attention_mask = None
212
+
213
+ text_embeddings = self.text_encoder(
214
+ text_input_ids.to(device),
215
+ attention_mask=attention_mask,
216
+ )
217
+ text_embeddings = text_embeddings[0]
218
+
219
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
220
+ bs_embed, seq_len, _ = text_embeddings.shape
221
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
222
+ text_embeddings = text_embeddings.view(
223
+ bs_embed * num_videos_per_prompt, seq_len, -1
224
+ )
225
+
226
+ # get unconditional embeddings for classifier free guidance
227
+ if do_classifier_free_guidance:
228
+ uncond_tokens: List[str]
229
+ if negative_prompt is None:
230
+ uncond_tokens = [""] * batch_size
231
+ elif type(prompt) is not type(negative_prompt):
232
+ raise TypeError(
233
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
234
+ f" {type(prompt)}."
235
+ )
236
+ elif isinstance(negative_prompt, str):
237
+ uncond_tokens = [negative_prompt]
238
+ elif batch_size != len(negative_prompt):
239
+ raise ValueError(
240
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
241
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
242
+ " the batch size of `prompt`."
243
+ )
244
+ else:
245
+ uncond_tokens = negative_prompt
246
+
247
+ max_length = text_input_ids.shape[-1]
248
+ uncond_input = self.tokenizer(
249
+ uncond_tokens,
250
+ padding="max_length",
251
+ max_length=max_length,
252
+ truncation=True,
253
+ return_tensors="pt",
254
+ )
255
+
256
+ if (
257
+ hasattr(self.text_encoder.config, "use_attention_mask")
258
+ and self.text_encoder.config.use_attention_mask
259
+ ):
260
+ attention_mask = uncond_input.attention_mask.to(device)
261
+ else:
262
+ attention_mask = None
263
+
264
+ uncond_embeddings = self.text_encoder(
265
+ uncond_input.input_ids.to(device),
266
+ attention_mask=attention_mask,
267
+ )
268
+ uncond_embeddings = uncond_embeddings[0]
269
+
270
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
271
+ seq_len = uncond_embeddings.shape[1]
272
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
273
+ uncond_embeddings = uncond_embeddings.view(
274
+ batch_size * num_videos_per_prompt, seq_len, -1
275
+ )
276
+
277
+ # For classifier free guidance, we need to do two forward passes.
278
+ # Here we concatenate the unconditional and text embeddings into a single batch
279
+ # to avoid doing two forward passes
280
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
281
+
282
+ return text_embeddings
283
+
284
+ @torch.no_grad()
285
+ def __call__(
286
+ self,
287
+ ref_image,
288
+ pose_images,
289
+ width,
290
+ height,
291
+ video_length,
292
+ num_inference_steps,
293
+ guidance_scale,
294
+ num_images_per_prompt=1,
295
+ eta: float = 0.0,
296
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
297
+ output_type: Optional[str] = "tensor",
298
+ return_dict: bool = True,
299
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
300
+ callback_steps: Optional[int] = 1,
301
+ **kwargs,
302
+ ):
303
+ # Default height and width to unet
304
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
305
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
306
+
307
+ device = self._execution_device
308
+
309
+ do_classifier_free_guidance = guidance_scale > 1.0
310
+
311
+ # Prepare timesteps
312
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
313
+ timesteps = self.scheduler.timesteps
314
+
315
+ batch_size = 1
316
+
317
+ # Prepare clip image embeds
318
+ clip_image = self.clip_image_processor.preprocess(
319
+ ref_image, return_tensors="pt"
320
+ ).pixel_values
321
+ clip_image_embeds = self.image_encoder(
322
+ clip_image.to(device, dtype=self.image_encoder.dtype)
323
+ ).image_embeds
324
+ encoder_hidden_states = clip_image_embeds.unsqueeze(1)
325
+ uncond_encoder_hidden_states = torch.zeros_like(encoder_hidden_states)
326
+
327
+ if do_classifier_free_guidance:
328
+ encoder_hidden_states = torch.cat(
329
+ [uncond_encoder_hidden_states, encoder_hidden_states], dim=0
330
+ )
331
+ reference_control_writer = ReferenceAttentionControl(
332
+ self.reference_unet,
333
+ do_classifier_free_guidance=do_classifier_free_guidance,
334
+ mode="write",
335
+ batch_size=batch_size,
336
+ fusion_blocks="full",
337
+ )
338
+ reference_control_reader = ReferenceAttentionControl(
339
+ self.denoising_unet,
340
+ do_classifier_free_guidance=do_classifier_free_guidance,
341
+ mode="read",
342
+ batch_size=batch_size,
343
+ fusion_blocks="full",
344
+ )
345
+
346
+ num_channels_latents = self.denoising_unet.in_channels
347
+ latents = self.prepare_latents(
348
+ batch_size * num_images_per_prompt,
349
+ num_channels_latents,
350
+ width,
351
+ height,
352
+ video_length,
353
+ clip_image_embeds.dtype,
354
+ device,
355
+ generator,
356
+ )
357
+
358
+ # Prepare extra step kwargs.
359
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
360
+
361
+ # Prepare ref image latents
362
+ ref_image_tensor = self.ref_image_processor.preprocess(
363
+ ref_image, height=height, width=width
364
+ ) # (bs, c, h, w)
365
+ ref_image_tensor = ref_image_tensor.to(
366
+ dtype=self.vae.dtype, device=self.vae.device
367
+ )
368
+ ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
369
+ ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
370
+
371
+ # Prepare a list of pose condition images
372
+ pose_cond_tensor_list = []
373
+ for pose_image in pose_images:
374
+ pose_cond_tensor = (
375
+ torch.from_numpy(np.array(pose_image.resize((width, height)))) / 255.0
376
+ )
377
+ pose_cond_tensor = pose_cond_tensor.permute(2, 0, 1).unsqueeze(
378
+ 1
379
+ ) # (c, 1, h, w)
380
+ pose_cond_tensor_list.append(pose_cond_tensor)
381
+ pose_cond_tensor = torch.cat(pose_cond_tensor_list, dim=1) # (c, t, h, w)
382
+ pose_cond_tensor = pose_cond_tensor.unsqueeze(0)
383
+ pose_cond_tensor = pose_cond_tensor.to(
384
+ device=device, dtype=self.pose_guider.dtype
385
+ )
386
+ pose_fea = self.pose_guider(pose_cond_tensor)
387
+ pose_fea = (
388
+ torch.cat([pose_fea] * 2) if do_classifier_free_guidance else pose_fea
389
+ )
390
+
391
+ # denoising loop
392
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
393
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
394
+ for i, t in enumerate(timesteps):
395
+ # 1. Forward reference image
396
+ if i == 0:
397
+ self.reference_unet(
398
+ ref_image_latents.repeat(
399
+ (2 if do_classifier_free_guidance else 1), 1, 1, 1
400
+ ),
401
+ torch.zeros_like(t),
402
+ # t,
403
+ encoder_hidden_states=encoder_hidden_states,
404
+ return_dict=False,
405
+ )
406
+ reference_control_reader.update(reference_control_writer)
407
+
408
+ # 3.1 expand the latents if we are doing classifier free guidance
409
+ latent_model_input = (
410
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
411
+ )
412
+ latent_model_input = self.scheduler.scale_model_input(
413
+ latent_model_input, t
414
+ )
415
+
416
+ noise_pred = self.denoising_unet(
417
+ latent_model_input,
418
+ t,
419
+ encoder_hidden_states=encoder_hidden_states,
420
+ pose_cond_fea=pose_fea,
421
+ return_dict=False,
422
+ )[0]
423
+
424
+ # perform guidance
425
+ if do_classifier_free_guidance:
426
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
427
+ noise_pred = noise_pred_uncond + guidance_scale * (
428
+ noise_pred_text - noise_pred_uncond
429
+ )
430
+
431
+ # compute the previous noisy sample x_t -> x_t-1
432
+ latents = self.scheduler.step(
433
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
434
+ )[0]
435
+
436
+ # call the callback, if provided
437
+ if i == len(timesteps) - 1 or (
438
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
439
+ ):
440
+ progress_bar.update()
441
+ if callback is not None and i % callback_steps == 0:
442
+ step_idx = i // getattr(self.scheduler, "order", 1)
443
+ callback(step_idx, t, latents)
444
+
445
+ reference_control_reader.clear()
446
+ reference_control_writer.clear()
447
+
448
+ # Post-processing
449
+ images = self.decode_latents(latents) # (b, c, f, h, w)
450
+
451
+ # Convert to tensor
452
+ if output_type == "tensor":
453
+ images = torch.from_numpy(images)
454
+
455
+ if not return_dict:
456
+ return images
457
+
458
+ return Pose2VideoPipelineOutput(videos=images)
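This pipeline stacks the per-frame pose images into a single (b, c, t, h, w) tensor before handing it to the pose guider. A small sketch of that packing with dummy PIL images; the sizes are illustrative:

import numpy as np
import torch
from PIL import Image

width, height, num_frames = 512, 768, 4
pose_images = [Image.new("RGB", (width, height)) for _ in range(num_frames)]

tensors = []
for img in pose_images:
    t = torch.from_numpy(np.array(img.resize((width, height)))) / 255.0   # (h, w, c) in [0, 1]
    tensors.append(t.permute(2, 0, 1).unsqueeze(1))                       # (c, 1, h, w)

pose_cond = torch.cat(tensors, dim=1).unsqueeze(0)   # (1, c, t, h, w)
print(pose_cond.shape)                               # torch.Size([1, 3, 4, 768, 512])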
pipelines/pipeline_pose2vid_long.py ADDED
@@ -0,0 +1,571 @@
1
+ # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/pipelines/pipeline_animation.py
2
+ import inspect
3
+ import math
4
+ from dataclasses import dataclass
5
+ from typing import Callable, List, Optional, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ from diffusers import DiffusionPipeline
10
+ from diffusers.image_processor import VaeImageProcessor
11
+ from diffusers.schedulers import (
12
+ DDIMScheduler,
13
+ DPMSolverMultistepScheduler,
14
+ EulerAncestralDiscreteScheduler,
15
+ EulerDiscreteScheduler,
16
+ LMSDiscreteScheduler,
17
+ PNDMScheduler,
18
+ )
19
+ from diffusers.utils import BaseOutput, deprecate, is_accelerate_available, logging
20
+ from diffusers.utils.torch_utils import randn_tensor
21
+ from einops import rearrange
22
+ from tqdm import tqdm
23
+ from transformers import CLIPImageProcessor
24
+
25
+ from musepose.models.mutual_self_attention import ReferenceAttentionControl
26
+ from musepose.pipelines.context import get_context_scheduler
27
+ from musepose.pipelines.utils import get_tensor_interpolation_method
28
+
29
+
30
+ @dataclass
31
+ class Pose2VideoPipelineOutput(BaseOutput):
32
+ videos: Union[torch.Tensor, np.ndarray]
33
+
34
+
35
+ class Pose2VideoPipeline(DiffusionPipeline):
36
+ _optional_components = []
37
+
38
+ def __init__(
39
+ self,
40
+ vae,
41
+ image_encoder,
42
+ reference_unet,
43
+ denoising_unet,
44
+ pose_guider,
45
+ scheduler: Union[
46
+ DDIMScheduler,
47
+ PNDMScheduler,
48
+ LMSDiscreteScheduler,
49
+ EulerDiscreteScheduler,
50
+ EulerAncestralDiscreteScheduler,
51
+ DPMSolverMultistepScheduler,
52
+ ],
53
+ image_proj_model=None,
54
+ tokenizer=None,
55
+ text_encoder=None,
56
+ ):
57
+ super().__init__()
58
+
59
+ self.register_modules(
60
+ vae=vae,
61
+ image_encoder=image_encoder,
62
+ reference_unet=reference_unet,
63
+ denoising_unet=denoising_unet,
64
+ pose_guider=pose_guider,
65
+ scheduler=scheduler,
66
+ image_proj_model=image_proj_model,
67
+ tokenizer=tokenizer,
68
+ text_encoder=text_encoder,
69
+ )
70
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
71
+ self.clip_image_processor = CLIPImageProcessor()
72
+ self.ref_image_processor = VaeImageProcessor(
73
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
74
+ )
75
+ self.cond_image_processor = VaeImageProcessor(
76
+ vae_scale_factor=self.vae_scale_factor,
77
+ do_convert_rgb=True,
78
+ do_normalize=False,
79
+ )
80
+
81
+ def enable_vae_slicing(self):
82
+ self.vae.enable_slicing()
83
+
84
+ def disable_vae_slicing(self):
85
+ self.vae.disable_slicing()
86
+
87
+ def enable_sequential_cpu_offload(self, gpu_id=0):
88
+ if is_accelerate_available():
89
+ from accelerate import cpu_offload
90
+ else:
91
+ raise ImportError("Please install accelerate via `pip install accelerate`")
92
+
93
+ device = torch.device(f"cuda:{gpu_id}")
94
+
95
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
96
+ if cpu_offloaded_model is not None:
97
+ cpu_offload(cpu_offloaded_model, device)
98
+
99
+ @property
100
+ def _execution_device(self):
101
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
102
+ return self.device
103
+ for module in self.unet.modules():
104
+ if (
105
+ hasattr(module, "_hf_hook")
106
+ and hasattr(module._hf_hook, "execution_device")
107
+ and module._hf_hook.execution_device is not None
108
+ ):
109
+ return torch.device(module._hf_hook.execution_device)
110
+ return self.device
111
+
112
+ def decode_latents(self, latents):
113
+ video_length = latents.shape[2]
114
+ latents = 1 / 0.18215 * latents
115
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
116
+ # video = self.vae.decode(latents).sample
117
+ video = []
118
+ for frame_idx in tqdm(range(latents.shape[0])):
119
+ video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
120
+ video = torch.cat(video)
121
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
122
+ video = (video / 2 + 0.5).clamp(0, 1)
123
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
124
+ video = video.cpu().float().numpy()
125
+ return video
126
+
127
+ def prepare_extra_step_kwargs(self, generator, eta):
128
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
129
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
130
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
131
+ # and should be between [0, 1]
132
+
133
+ accepts_eta = "eta" in set(
134
+ inspect.signature(self.scheduler.step).parameters.keys()
135
+ )
136
+ extra_step_kwargs = {}
137
+ if accepts_eta:
138
+ extra_step_kwargs["eta"] = eta
139
+
140
+ # check if the scheduler accepts generator
141
+ accepts_generator = "generator" in set(
142
+ inspect.signature(self.scheduler.step).parameters.keys()
143
+ )
144
+ if accepts_generator:
145
+ extra_step_kwargs["generator"] = generator
146
+ return extra_step_kwargs
147
+
148
+    def prepare_latents(
+        self,
+        batch_size,
+        num_channels_latents,
+        width,
+        height,
+        video_length,
+        dtype,
+        device,
+        generator,
+        latents=None,
+    ):
+        shape = (
+            batch_size,
+            num_channels_latents,
+            video_length,
+            height // self.vae_scale_factor,
+            width // self.vae_scale_factor,
+        )
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        if latents is None:
+            latents = randn_tensor(
+                shape, generator=generator, device=device, dtype=dtype
+            )
+        else:
+            latents = latents.to(device)
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+        return latents
+
+    def _encode_prompt(
+        self,
+        prompt,
+        device,
+        num_videos_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt,
+    ):
+        batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(
+            prompt, padding="longest", return_tensors="pt"
+        ).input_ids
+
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+            text_input_ids, untruncated_ids
+        ):
+            # tokens beyond the tokenizer's model_max_length are dropped by truncation
+            removed_text = self.tokenizer.batch_decode(
+                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+            )
+
+        if (
+            hasattr(self.text_encoder.config, "use_attention_mask")
+            and self.text_encoder.config.use_attention_mask
+        ):
+            attention_mask = text_inputs.attention_mask.to(device)
+        else:
+            attention_mask = None
+
+        text_embeddings = self.text_encoder(
+            text_input_ids.to(device),
+            attention_mask=attention_mask,
+        )
+        text_embeddings = text_embeddings[0]
+
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        bs_embed, seq_len, _ = text_embeddings.shape
+        text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
+        text_embeddings = text_embeddings.view(
+            bs_embed * num_videos_per_prompt, seq_len, -1
+        )
+
+        # get unconditional embeddings for classifier free guidance
+        if do_classifier_free_guidance:
+            uncond_tokens: List[str]
+            if negative_prompt is None:
+                uncond_tokens = [""] * batch_size
+            elif type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif isinstance(negative_prompt, str):
+                uncond_tokens = [negative_prompt]
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+            else:
+                uncond_tokens = negative_prompt
+
+            max_length = text_input_ids.shape[-1]
+            uncond_input = self.tokenizer(
+                uncond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+
+            if (
+                hasattr(self.text_encoder.config, "use_attention_mask")
+                and self.text_encoder.config.use_attention_mask
+            ):
+                attention_mask = uncond_input.attention_mask.to(device)
+            else:
+                attention_mask = None
+
+            uncond_embeddings = self.text_encoder(
+                uncond_input.input_ids.to(device),
+                attention_mask=attention_mask,
+            )
+            uncond_embeddings = uncond_embeddings[0]
+
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = uncond_embeddings.shape[1]
+            uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
+            uncond_embeddings = uncond_embeddings.view(
+                batch_size * num_videos_per_prompt, seq_len, -1
+            )
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        return text_embeddings
+
+    def interpolate_latents(
+        self, latents: torch.Tensor, interpolation_factor: int, device
+    ):
+        if interpolation_factor < 2:
+            return latents
+
+        new_latents = torch.zeros(
+            (
+                latents.shape[0],
+                latents.shape[1],
+                ((latents.shape[2] - 1) * interpolation_factor) + 1,
+                latents.shape[3],
+                latents.shape[4],
+            ),
+            device=latents.device,
+            dtype=latents.dtype,
+        )
+
+        org_video_length = latents.shape[2]
+        rate = [i / interpolation_factor for i in range(interpolation_factor)][1:]
+
+        new_index = 0
+
+        v0 = None
+        v1 = None
+
+        for i0, i1 in zip(range(org_video_length), range(org_video_length)[1:]):
+            v0 = latents[:, :, i0, :, :]
+            v1 = latents[:, :, i1, :, :]
+
+            new_latents[:, :, new_index, :, :] = v0
+            new_index += 1
+
+            for f in rate:
+                v = get_tensor_interpolation_method()(
+                    v0.to(device=device), v1.to(device=device), f
+                )
+                new_latents[:, :, new_index, :, :] = v.to(latents.device)
+                new_index += 1
+
+        new_latents[:, :, new_index, :, :] = v1
+        new_index += 1
+
+        return new_latents
+
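+    # Note on interpolate_latents: an input of F latent frames is expanded to
+    # (F - 1) * interpolation_factor + 1 frames, e.g. 24 frames become 47 with
+    # interpolation_factor=2. Intermediate frames come from the function returned
+    # by get_tensor_interpolation_method() (see pipelines/utils.py).
+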
+    @torch.no_grad()
+    def __call__(
+        self,
+        ref_image,
+        pose_images,
+        width,
+        height,
+        video_length,
+        num_inference_steps,
+        guidance_scale,
+        num_images_per_prompt=1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        output_type: Optional[str] = "tensor",
+        return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback_steps: Optional[int] = 1,
+        context_schedule="uniform",
+        context_frames=24,
+        context_stride=1,
+        context_overlap=4,
+        context_batch_size=1,
+        interpolation_factor=1,
+        **kwargs,
+    ):
+        # Default height and width to the denoising unet's sample size
+        height = height or self.denoising_unet.config.sample_size * self.vae_scale_factor
+        width = width or self.denoising_unet.config.sample_size * self.vae_scale_factor
+
+        device = self._execution_device
+
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        # Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self.scheduler.timesteps
+
+        batch_size = 1
+
+        # Prepare clip image embeds
+        clip_image = self.clip_image_processor.preprocess(
+            ref_image.resize((224, 224)), return_tensors="pt"
+        ).pixel_values
+        clip_image_embeds = self.image_encoder(
+            clip_image.to(device, dtype=self.image_encoder.dtype)
+        ).image_embeds
+        encoder_hidden_states = clip_image_embeds.unsqueeze(1)
+        uncond_encoder_hidden_states = torch.zeros_like(encoder_hidden_states)
+
+        if do_classifier_free_guidance:
+            encoder_hidden_states = torch.cat(
+                [uncond_encoder_hidden_states, encoder_hidden_states], dim=0
+            )
+
+        reference_control_writer = ReferenceAttentionControl(
+            self.reference_unet,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            mode="write",
+            batch_size=batch_size,
+            fusion_blocks="full",
+        )
+        reference_control_reader = ReferenceAttentionControl(
+            self.denoising_unet,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            mode="read",
+            batch_size=batch_size,
+            fusion_blocks="full",
+        )
+
+        num_channels_latents = self.denoising_unet.in_channels
+        latents = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            width,
+            height,
+            video_length,
+            clip_image_embeds.dtype,
+            device,
+            generator,
+        )
+
+        # Prepare extra step kwargs.
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # Prepare ref image latents
+        ref_image_tensor = self.ref_image_processor.preprocess(
+            ref_image, height=height, width=width
+        )  # (bs, c, width, height)
+        ref_image_tensor = ref_image_tensor.to(
+            dtype=self.vae.dtype, device=self.vae.device
+        )
+        ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
+        ref_image_latents = ref_image_latents * 0.18215  # (b, 4, h, w)
+
+        # Prepare a list of pose condition images
+        pose_cond_tensor_list = []
+        for pose_image in pose_images:
+            pose_cond_tensor = self.cond_image_processor.preprocess(
+                pose_image, height=height, width=width
+            )
+            pose_cond_tensor = pose_cond_tensor.unsqueeze(2)  # (bs, c, 1, h, w)
+            pose_cond_tensor_list.append(pose_cond_tensor)
+        pose_cond_tensor = torch.cat(pose_cond_tensor_list, dim=2)  # (bs, c, t, h, w)
+        pose_cond_tensor = pose_cond_tensor.to(
+            device=device, dtype=self.pose_guider.dtype
+        )
+        pose_fea = self.pose_guider(pose_cond_tensor)
+
+        context_scheduler = get_context_scheduler(context_schedule)
+
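+        # The denoising loop below processes the video in overlapping temporal
+        # windows (context_frames long, with context_overlap frames shared between
+        # neighbours). Each window is denoised separately; per-window noise
+        # predictions are summed into noise_pred while counter records how many
+        # windows touched each frame, so overlapping frames end up averaged.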
+        # denoising loop
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                noise_pred = torch.zeros(
+                    (
+                        latents.shape[0] * (2 if do_classifier_free_guidance else 1),
+                        *latents.shape[1:],
+                    ),
+                    device=latents.device,
+                    dtype=latents.dtype,
+                )
+                counter = torch.zeros(
+                    (1, 1, latents.shape[2], 1, 1),
+                    device=latents.device,
+                    dtype=latents.dtype,
+                )
+
+                # 1. Forward reference image
+                if i == 0:
+                    self.reference_unet(
+                        ref_image_latents.repeat(
+                            (2 if do_classifier_free_guidance else 1), 1, 1, 1
+                        ),
+                        torch.zeros_like(t),
+                        # t,
+                        encoder_hidden_states=encoder_hidden_states,
+                        return_dict=False,
+                    )
+                    reference_control_reader.update(reference_control_writer)
+
+                context_queue = list(
+                    context_scheduler(
+                        0,
+                        num_inference_steps,
+                        latents.shape[2],
+                        context_frames,
+                        context_stride,
+                        context_overlap,
+                    )
+                )
+
+                num_context_batches = math.ceil(len(context_queue) / context_batch_size)
+                global_context = []
+                for k in range(num_context_batches):  # separate index so the timestep index i is not shadowed
+                    global_context.append(
+                        context_queue[
+                            k * context_batch_size : (k + 1) * context_batch_size
+                        ]
+                    )
+
+                for context in global_context:
+                    # 3.1 expand the latents if we are doing classifier free guidance
+                    latent_model_input = (
+                        torch.cat([latents[:, :, c] for c in context])
+                        .to(device)
+                        .repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1)
+                    )
+                    latent_model_input = self.scheduler.scale_model_input(
+                        latent_model_input, t
+                    )
+                    b, c, f, h, w = latent_model_input.shape
+                    latent_pose_input = torch.cat(
+                        [pose_fea[:, :, c] for c in context]
+                    ).repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1)
+
+                    pred = self.denoising_unet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=encoder_hidden_states[:b],
+                        pose_cond_fea=latent_pose_input,
+                        return_dict=False,
+                    )[0]
+
+                    for c in context:
+                        noise_pred[:, :, c] = noise_pred[:, :, c] + pred
+                        counter[:, :, c] = counter[:, :, c] + 1
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = (noise_pred / counter).chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (
+                        noise_pred_text - noise_pred_uncond
+                    )
+                else:
+                    # average the overlapping window predictions even without guidance
+                    noise_pred = noise_pred / counter
+
+                latents = self.scheduler.step(
+                    noise_pred, t, latents, **extra_step_kwargs
+                ).prev_sample
+
+                if i == len(timesteps) - 1 or (
+                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+                ):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        step_idx = i // getattr(self.scheduler, "order", 1)
+                        callback(step_idx, t, latents)
+
+            reference_control_reader.clear()
+            reference_control_writer.clear()
+
+        if interpolation_factor > 0:
+            latents = self.interpolate_latents(latents, interpolation_factor, device)
+        # Post-processing
+        images = self.decode_latents(latents)  # (b, c, f, h, w)
+
+        # Convert to tensor
+        if output_type == "tensor":
+            images = torch.from_numpy(images)
+
+        if not return_dict:
+            return images
+
+        return Pose2VideoPipelineOutput(videos=images)
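
A minimal, illustrative usage sketch for the pipeline defined above (assumptions: the class is named Pose2VideoPipeline as its output class suggests, it subclasses diffusers' DiffusionPipeline so `.to()` is available, and the vae, image_encoder, reference_unet, denoising_unet, pose_guider and scheduler objects have already been built and loaded with MusePose weights; file paths are placeholders):

    import glob

    import torch
    from PIL import Image

    from utils.util import save_videos_grid

    pipe = Pose2VideoPipeline(
        vae=vae,
        image_encoder=image_encoder,
        reference_unet=reference_unet,
        denoising_unet=denoising_unet,
        pose_guider=pose_guider,
        scheduler=scheduler,
    ).to("cuda")

    ref_image = Image.open("ref.png").convert("RGB")
    pose_paths = sorted(glob.glob("pose_frames/*.png"))  # hypothetical pose sequence
    pose_images = [Image.open(p).convert("RGB") for p in pose_paths]

    output = pipe(
        ref_image,
        pose_images,
        width=512,
        height=768,
        video_length=len(pose_images),
        num_inference_steps=25,
        guidance_scale=3.5,
        generator=torch.manual_seed(42),
        context_frames=24,
        context_overlap=4,
    )
    # output.videos is a (b, c, f, h, w) tensor in [0, 1]
    save_videos_grid(output.videos, "output/result.mp4", n_rows=1, fps=12)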
pipelines/utils.py ADDED
@@ -0,0 +1,29 @@
+import torch
+
+tensor_interpolation = None
+
+
+def get_tensor_interpolation_method():
+    return tensor_interpolation
+
+
+def set_tensor_interpolation_method(is_slerp):
+    global tensor_interpolation
+    tensor_interpolation = slerp if is_slerp else linear
+
+
+def linear(v1, v2, t):
+    return (1.0 - t) * v1 + t * v2
+
+
+def slerp(
+    v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995
+) -> torch.Tensor:
+    u0 = v0 / v0.norm()
+    u1 = v1 / v1.norm()
+    dot = (u0 * u1).sum()
+    if dot.abs() > DOT_THRESHOLD:
+        # logger.info(f'warning: v0 and v1 close to parallel, using linear interpolation instead.')
+        return (1.0 - t) * v0 + t * v1
+    omega = dot.acos()
+    return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin()
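
Because tensor_interpolation starts as None, set_tensor_interpolation_method must be called before the pipeline's interpolate_latents can use it. A small illustrative example (shapes are arbitrary; assumes the repository root is on PYTHONPATH):

    import torch

    from pipelines.utils import (
        get_tensor_interpolation_method,
        set_tensor_interpolation_method,
    )

    set_tensor_interpolation_method(is_slerp=True)  # False selects plain linear blending
    v0 = torch.randn(1, 4, 64, 64)
    v1 = torch.randn(1, 4, 64, 64)
    halfway = get_tensor_interpolation_method()(v0, v1, 0.5)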
utils/util.py ADDED
@@ -0,0 +1,133 @@
+import importlib
+import os
+import os.path as osp
+import shutil
+import sys
+from pathlib import Path
+
+import av
+import numpy as np
+import torch
+import torchvision
+from einops import rearrange
+from PIL import Image
+
+
+def seed_everything(seed):
+    import random
+
+    import numpy as np
+
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed % (2**32))
+    random.seed(seed)
+
+
+def import_filename(filename):
+    spec = importlib.util.spec_from_file_location("mymodule", filename)
+    module = importlib.util.module_from_spec(spec)
+    sys.modules[spec.name] = module
+    spec.loader.exec_module(module)
+    return module
+
+
+def delete_additional_ckpt(base_path, num_keep):
+    dirs = []
+    for d in os.listdir(base_path):
+        if d.startswith("checkpoint-"):
+            dirs.append(d)
+    num_tot = len(dirs)
+    if num_tot <= num_keep:
+        return
+    # sort checkpoints by step and delete the earliest ones
+    del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep]
+    for d in del_dirs:
+        path_to_dir = osp.join(base_path, d)
+        if osp.exists(path_to_dir):
+            shutil.rmtree(path_to_dir)
+
+
+def save_videos_from_pil(pil_images, path, fps=8):
+    save_fmt = Path(path).suffix
+    os.makedirs(os.path.dirname(path), exist_ok=True)
+    width, height = pil_images[0].size
+
+    if save_fmt == ".mp4":
+        codec = "libx264"
+        container = av.open(path, "w")
+        stream = container.add_stream(codec, rate=fps)
+
+        stream.width = width
+        stream.height = height
+        stream.pix_fmt = "yuv420p"
+        stream.bit_rate = 10000000
+        stream.options["crf"] = "18"
+
+        for pil_image in pil_images:
+            # pil_image = Image.fromarray(image_arr).convert("RGB")
+            av_frame = av.VideoFrame.from_image(pil_image)
+            container.mux(stream.encode(av_frame))
+        container.mux(stream.encode())
+        container.close()
+
+    elif save_fmt == ".gif":
+        pil_images[0].save(
+            fp=path,
+            format="GIF",
+            append_images=pil_images[1:],
+            save_all=True,
+            duration=(1 / fps * 1000),
+            loop=0,
+        )
+    else:
+        raise ValueError("Unsupported file type. Use .mp4 or .gif.")
+
+
+def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
+    videos = rearrange(videos, "b c t h w -> t b c h w")
+    height, width = videos.shape[-2:]
+    outputs = []
+
+    for x in videos:
+        x = torchvision.utils.make_grid(x, nrow=n_rows)  # (c h w)
+        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)  # (h w c)
+        if rescale:
+            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
+        x = (x * 255).numpy().astype(np.uint8)
+        x = Image.fromarray(x)
+
+        outputs.append(x)
+
+    os.makedirs(os.path.dirname(path), exist_ok=True)
+
+    save_videos_from_pil(outputs, path, fps)
+
+
+def read_frames(video_path):
+    container = av.open(video_path)
+
+    video_stream = next(s for s in container.streams if s.type == "video")
+    frames = []
+    for packet in container.demux(video_stream):
+        for frame in packet.decode():
+            image = Image.frombytes(
+                "RGB",
+                (frame.width, frame.height),
+                frame.to_rgb().to_ndarray(),
+            )
+            frames.append(image)
+    container.close()
+
+    return frames
+
+
+def get_fps(video_path):
+    container = av.open(video_path)
+    video_stream = next(s for s in container.streams if s.type == "video")
+    fps = video_stream.average_rate
+    container.close()
+    return fps
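
A short, hypothetical round trip with these helpers (file names are placeholders; assumes the repository root is on PYTHONPATH):

    from utils.util import get_fps, read_frames, save_videos_from_pil

    frames = read_frames("input_dance.mp4")   # list of PIL.Image frames
    fps = get_fps("input_dance.mp4")          # PyAV returns a Fraction, e.g. 30/1
    save_videos_from_pil(frames, "output/copy.mp4", fps=int(fps))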