Bai-YT committed on
Commit
66982e9
0 Parent(s):

Gradio App for ConsistencyTTA V1

Files changed (43)
  1. .gitignore +6 -0
  2. README.md +7 -0
  3. audioldm/hifigan/__init__.py +7 -0
  4. audioldm/hifigan/models.py +127 -0
  5. audioldm/hifigan/utilities.py +88 -0
  6. audioldm/latent_diffusion/attention.py +469 -0
  7. audioldm/latent_diffusion/util.py +293 -0
  8. audioldm/stft.py +257 -0
  9. audioldm/utils.py +177 -0
  10. audioldm/variational_autoencoder/__init__.py +1 -0
  11. audioldm/variational_autoencoder/autoencoder.py +131 -0
  12. audioldm/variational_autoencoder/distributions.py +102 -0
  13. audioldm/variational_autoencoder/modules.py +1067 -0
  14. consistencytta.py +200 -0
  15. consistencytta_clapft_ckpt/.DS_Store +0 -0
  16. diffusers/__init__.py +2 -0
  17. diffusers/models/__init__.py +23 -0
  18. diffusers/models/activations.py +12 -0
  19. diffusers/models/attention.py +523 -0
  20. diffusers/models/attention_processor.py +1646 -0
  21. diffusers/models/dual_transformer_2d.py +151 -0
  22. diffusers/models/embeddings.py +480 -0
  23. diffusers/models/loaders.py +1481 -0
  24. diffusers/models/modeling_utils.py +978 -0
  25. diffusers/models/prior_transformer.py +194 -0
  26. diffusers/models/resnet.py +839 -0
  27. diffusers/models/transformer_2d.py +333 -0
  28. diffusers/models/unet_2d.py +315 -0
  29. diffusers/models/unet_2d_blocks.py +0 -0
  30. diffusers/models/unet_2d_condition.py +907 -0
  31. diffusers/models/unet_2d_condition_guided.py +945 -0
  32. diffusers/scheduling_heun_discrete.py +387 -0
  33. diffusers/utils/configuration_utils.py +647 -0
  34. diffusers/utils/constants.py +34 -0
  35. diffusers/utils/deprecation_utils.py +49 -0
  36. diffusers/utils/hub_utils.py +357 -0
  37. diffusers/utils/import_utils.py +649 -0
  38. diffusers/utils/logging.py +342 -0
  39. diffusers/utils/outputs.py +108 -0
  40. diffusers/utils/scheduling_utils.py +176 -0
  41. diffusers/utils/torch_utils.py +83 -0
  42. run_gradio.py +87 -0
  43. tango_diffusion_light.json +46 -0
.gitignore ADDED
@@ -0,0 +1,6 @@
+ __pycache__
+ */__pycache__
+ flagged
+ *.wav
+ *.pt
+ *.DS_Store
README.md ADDED
@@ -0,0 +1,7 @@
+ ## Gradio App for ConsistencyTTA
+
+ Required packages:
+ `numpy scipy torch torchaudio einops soundfile librosa transformers gradio`
+
+ To run:
+ `python run_gradio.py`
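Note: `run_gradio.py` (87 lines in this commit) is listed above but its contents are not shown here. Below is a rough, hypothetical sketch of how such a Gradio entry point is commonly wired; the `ConsistencyTTA` class name, the `generate` wrapper, and the model call signature are assumptions, not the file's actual code.

```python
# Hypothetical sketch only -- the real run_gradio.py in this commit may differ.
import gradio as gr
import torch
from consistencytta import ConsistencyTTA  # class name assumed from consistencytta.py

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ConsistencyTTA().to(device).eval()  # checkpoint loading details omitted

def generate(prompt: str, seed: int = 0):
    """Generate a 16 kHz waveform from a text prompt (assumed call signature)."""
    torch.manual_seed(int(seed))
    with torch.no_grad():
        wav = model(prompt)  # assumed: returns a waveform tensor
    return 16000, wav.squeeze().cpu().numpy()

demo = gr.Interface(
    fn=generate,
    inputs=[gr.Textbox(label="Text prompt"), gr.Number(label="Seed", value=0)],
    outputs=gr.Audio(label="Generated audio"),
    title="ConsistencyTTA",
)

if __name__ == "__main__":
    demo.launch()
```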
audioldm/hifigan/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from .models import Generator
+
+
+ class AttrDict(dict):
+     def __init__(self, *args, **kwargs):
+         super(AttrDict, self).__init__(*args, **kwargs)
+         self.__dict__ = self
audioldm/hifigan/models.py ADDED
@@ -0,0 +1,127 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn import Conv1d, ConvTranspose1d
+ from torch.nn.utils.parametrizations import weight_norm
+ from torch.nn.utils.parametrize import remove_parametrizations
+
+
+ LRELU_SLOPE = 0.1
+
+
+ def init_weights(m, mean=0.0, std=0.01):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         m.weight.data.normal_(mean, std)
+
+
+ def get_padding(kernel_size, dilation=1):
+     return int((kernel_size * dilation - dilation) / 2)
+
+
+ class ResBlock(torch.nn.Module):
+     def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
+         super(ResBlock, self).__init__()
+         self.h = h
+         self.convs1 = nn.ModuleList([
+             weight_norm(Conv1d(
+                 channels, channels, kernel_size, 1, dilation=dilation[0],
+                 padding=get_padding(kernel_size, dilation[0]),
+             )),
+             weight_norm(Conv1d(
+                 channels, channels, kernel_size, 1, dilation=dilation[1],
+                 padding=get_padding(kernel_size, dilation[1]),
+             )),
+             weight_norm(Conv1d(
+                 channels, channels, kernel_size, 1, dilation=dilation[2],
+                 padding=get_padding(kernel_size, dilation[2]),
+             )),
+         ])
+         self.convs1.apply(init_weights)
+
+         self.convs2 = nn.ModuleList([
+             weight_norm(Conv1d(
+                 channels, channels, kernel_size, 1, dilation=1,
+                 padding=get_padding(kernel_size, 1),
+             )),
+             weight_norm(Conv1d(
+                 channels, channels, kernel_size, 1, dilation=1,
+                 padding=get_padding(kernel_size, 1),
+             )),
+             weight_norm(Conv1d(
+                 channels, channels, kernel_size, 1, dilation=1,
+                 padding=get_padding(kernel_size, 1),
+             )),
+         ])
+         self.convs2.apply(init_weights)
+
+     def forward(self, x):
+         for c1, c2 in zip(self.convs1, self.convs2):
+             xt = F.leaky_relu(x, LRELU_SLOPE)
+             xt = c1(xt)
+             xt = F.leaky_relu(xt, LRELU_SLOPE)
+             xt = c2(xt)
+             x = xt + x
+         return x
+
+     def remove_weight_norm(self):
+         for l in self.convs1:
+             remove_parametrizations(l, 'weight')
+         for l in self.convs2:
+             remove_parametrizations(l, 'weight')
+
+
+ class Generator(torch.nn.Module):
+     def __init__(self, h):
+         super(Generator, self).__init__()
+         self.h = h
+         self.num_kernels = len(h.resblock_kernel_sizes)
+         self.num_upsamples = len(h.upsample_rates)
+         self.conv_pre = weight_norm(
+             Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
+         )
+         resblock = ResBlock
+
+         self.ups = nn.ModuleList()
+         for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+             self.ups.append(weight_norm(ConvTranspose1d(
+                 h.upsample_initial_channel // (2**i),
+                 h.upsample_initial_channel // (2 ** (i + 1)),
+                 k, u, padding=(k - u) // 2,
+             )))
+
+         self.resblocks = nn.ModuleList()
+         for i in range(len(self.ups)):
+             ch = h.upsample_initial_channel // (2 ** (i + 1))
+             for k, d in zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes):
+                 self.resblocks.append(resblock(h, ch, k, d))
+
+         self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+         self.ups.apply(init_weights)
+         self.conv_post.apply(init_weights)
+
+     def forward(self, x):
+         x = self.conv_pre(x)
+         for i in range(self.num_upsamples):
+             x = F.leaky_relu(x, LRELU_SLOPE)
+             x = self.ups[i](x)
+             xs = None
+             for j in range(self.num_kernels):
+                 if xs is None:
+                     xs = self.resblocks[i * self.num_kernels + j](x)
+                 else:
+                     xs += self.resblocks[i * self.num_kernels + j](x)
+             x = xs / self.num_kernels
+         x = F.leaky_relu(x)
+         x = self.conv_post(x)
+         x = torch.tanh(x)
+
+         return x
+
+     def remove_weight_norm(self):
+         for l in self.ups:
+             remove_parametrizations(l, 'weight')
+         for l in self.resblocks:
+             l.remove_weight_norm()
+         remove_parametrizations(self.conv_pre, 'weight')
+         remove_parametrizations(self.conv_post, 'weight')
audioldm/hifigan/utilities.py ADDED
@@ -0,0 +1,88 @@
+ import torch
+
+ import audioldm.hifigan as hifigan
+
+
+ HIFIGAN_16K_64 = {
+     "resblock": "1",
+     "num_gpus": 6,
+     "batch_size": 16,
+     "learning_rate": 0.0002,
+     "adam_b1": 0.8,
+     "adam_b2": 0.99,
+     "lr_decay": 0.999,
+     "seed": 1234,
+     "upsample_rates": [5, 4, 2, 2, 2],
+     "upsample_kernel_sizes": [16, 16, 8, 4, 4],
+     "upsample_initial_channel": 1024,
+     "resblock_kernel_sizes": [3, 7, 11],
+     "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+     "segment_size": 8192,
+     "num_mels": 64,
+     "num_freq": 1025,
+     "n_fft": 1024,
+     "hop_size": 160,
+     "win_size": 1024,
+     "sampling_rate": 16000,
+     "fmin": 0,
+     "fmax": 8000,
+     "fmax_for_loss": None,
+     "num_workers": 4,
+     "dist_config": {
+         "dist_backend": "nccl",
+         "dist_url": "tcp://localhost:54321",
+         "world_size": 1,
+     },
+ }
+
+
+ def get_available_checkpoint_keys(model, ckpt):
+     print("==> Attempt to reload from %s" % ckpt)
+     state_dict = torch.load(ckpt)["state_dict"]
+     current_state_dict = model.state_dict()
+     new_state_dict = {}
+     for k in state_dict.keys():
+         if (
+             k in current_state_dict.keys()
+             and current_state_dict[k].size() == state_dict[k].size()
+         ):
+             new_state_dict[k] = state_dict[k]
+         else:
+             print("==> WARNING: Skipping %s" % k)
+     print(
+         "%s out of %s keys are matched"
+         % (len(new_state_dict.keys()), len(state_dict.keys()))
+     )
+     return new_state_dict
+
+
+ def get_param_num(model):
+     num_param = sum(param.numel() for param in model.parameters())
+     return num_param
+
+
+ def get_vocoder(config, device):
+     config = hifigan.AttrDict(HIFIGAN_16K_64)
+     vocoder = hifigan.Generator(config)
+     vocoder.eval()
+     vocoder.remove_weight_norm()
+     vocoder.to(device)
+     return vocoder
+
+
+ def vocoder_infer(mels, vocoder, allow_grad=False, lengths=None):
+     vocoder.eval()
+
+     if allow_grad:
+         wavs = vocoder(mels).squeeze(1).float()
+         wavs = wavs - (wavs.max() + wavs.min()) / 2
+     else:
+         with torch.no_grad():
+             wavs = vocoder(mels).squeeze(1).float()
+             wavs = wavs - (wavs.max() + wavs.min()) / 2
+         wavs = (wavs.cpu().numpy() * 32768).astype("int16")
+
+     if lengths is not None:
+         wavs = wavs[:, :lengths]
+
+     return wavs
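For orientation, a minimal hedged usage sketch of the two helpers above; the mel tensor below is a placeholder (shape `[batch, 64, frames]`, matching `num_mels = 64` in `HIFIGAN_16K_64`), and loading of pretrained vocoder weights is omitted.

```python
# Hedged sketch: build the HiFi-GAN vocoder and synthesize a waveform from a mel spectrogram.
import torch
from audioldm.hifigan.utilities import get_vocoder, vocoder_infer

device = "cuda" if torch.cuda.is_available() else "cpu"
vocoder = get_vocoder(None, device)             # config argument is ignored; HIFIGAN_16K_64 is used
mels = torch.randn(1, 64, 500, device=device)   # placeholder mel spectrogram
wav = vocoder_infer(mels, vocoder)              # int16 numpy array at 16 kHz
```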
audioldm/latent_diffusion/attention.py ADDED
@@ -0,0 +1,469 @@
+ from inspect import isfunction
+ import math
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from einops import rearrange
+
+ from audioldm.latent_diffusion.util import checkpoint
+
+
+ def exists(val):
+     return val is not None
+
+
+ def uniq(arr):
+     return {el: True for el in arr}.keys()
+
+
+ def default(val, d):
+     if exists(val):
+         return val
+     return d() if isfunction(d) else d
+
+
+ def max_neg_value(t):
+     return -torch.finfo(t.dtype).max
+
+
+ def init_(tensor):
+     dim = tensor.shape[-1]
+     std = 1 / math.sqrt(dim)
+     tensor.uniform_(-std, std)
+     return tensor
+
+
+ # feedforward
+ class GEGLU(nn.Module):
+     def __init__(self, dim_in, dim_out):
+         super().__init__()
+         self.proj = nn.Linear(dim_in, dim_out * 2)
+
+     def forward(self, x):
+         x, gate = self.proj(x).chunk(2, dim=-1)
+         return x * F.gelu(gate)
+
+
+ class FeedForward(nn.Module):
+     def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
+         super().__init__()
+         inner_dim = int(dim * mult)
+         dim_out = default(dim_out, dim)
+         project_in = (
+             nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
+             if not glu
+             else GEGLU(dim, inner_dim)
+         )
+
+         self.net = nn.Sequential(
+             project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+
+ def zero_module(module):
+     """
+     Zero out the parameters of a module and return it.
+     """
+     for p in module.parameters():
+         p.detach().zero_()
+     return module
+
+
+ def Normalize(in_channels):
+     return torch.nn.GroupNorm(
+         num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+     )
+
+
+ class LinearAttention(nn.Module):
+     def __init__(self, dim, heads=4, dim_head=32):
+         super().__init__()
+         self.heads = heads
+         hidden_dim = dim_head * heads
+         self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+         self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+     def forward(self, x):
+         b, c, h, w = x.shape
+         qkv = self.to_qkv(x)
+         q, k, v = rearrange(
+             qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
+         )
+         k = k.softmax(dim=-1)
+         context = torch.einsum("bhdn,bhen->bhde", k, v)
+         out = torch.einsum("bhde,bhdn->bhen", context, q)
+         out = rearrange(
+             out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
+         )
+         return self.to_out(out)
+
+
+ class SpatialSelfAttention(nn.Module):
+     def __init__(self, in_channels):
+         super().__init__()
+         self.in_channels = in_channels
+
+         self.norm = Normalize(in_channels)
+         self.q = torch.nn.Conv2d(
+             in_channels, in_channels, kernel_size=1, stride=1, padding=0
+         )
+         self.k = torch.nn.Conv2d(
+             in_channels, in_channels, kernel_size=1, stride=1, padding=0
+         )
+         self.v = torch.nn.Conv2d(
+             in_channels, in_channels, kernel_size=1, stride=1, padding=0
+         )
+         self.proj_out = torch.nn.Conv2d(
+             in_channels, in_channels, kernel_size=1, stride=1, padding=0
+         )
+
+     def forward(self, x):
+         h_ = x
+         h_ = self.norm(h_)
+         q = self.q(h_)
+         k = self.k(h_)
+         v = self.v(h_)
+
+         # compute attention
+         b, c, h, w = q.shape
+         q = rearrange(q, "b c h w -> b (h w) c")
+         k = rearrange(k, "b c h w -> b c (h w)")
+         w_ = torch.einsum("bij,bjk->bik", q, k)
+
+         w_ = w_ * (int(c) ** (-0.5))
+         w_ = torch.nn.functional.softmax(w_, dim=2)
+
+         # attend to values
+         v = rearrange(v, "b c h w -> b c (h w)")
+         w_ = rearrange(w_, "b i j -> b j i")
+         h_ = torch.einsum("bij,bjk->bik", v, w_)
+         h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
+         h_ = self.proj_out(h_)
+
+         return x + h_
+
+
+ class CrossAttention(nn.Module):
+     """
+     ### Cross Attention Layer
+     This falls back to self-attention when conditional embeddings are not specified.
+     """
+
+     # use_flash_attention: bool = True
+     use_flash_attention: bool = False
+
+     def __init__(
+         self,
+         query_dim,
+         context_dim=None,
+         heads=8,
+         dim_head=64,
+         dropout=0.0,
+         is_inplace: bool = True,
+     ):
+         # def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True):
+         """
+         :param d_model: is the input embedding size
+         :param n_heads: is the number of attention heads
+         :param d_head: is the size of an attention head
+         :param d_cond: is the size of the conditional embeddings
+         :param is_inplace: specifies whether to perform the attention softmax computation inplace to
+             save memory
+         """
+         super().__init__()
+
+         self.is_inplace = is_inplace
+         self.n_heads = heads
+         self.d_head = dim_head
+
+         # Attention scaling factor
+         self.scale = dim_head**-0.5
+
+         # The normal self-attention layer
+         if context_dim is None:
+             context_dim = query_dim
+
+         # Query, key and value mappings
+         d_attn = dim_head * heads
+         self.to_q = nn.Linear(query_dim, d_attn, bias=False)
+         self.to_k = nn.Linear(context_dim, d_attn, bias=False)
+         self.to_v = nn.Linear(context_dim, d_attn, bias=False)
+
+         # Final linear layer
+         self.to_out = nn.Sequential(nn.Linear(d_attn, query_dim), nn.Dropout(dropout))
+
+         # Setup [flash attention](https://github.com/HazyResearch/flash-attention).
+         # Flash attention is only used if it's installed
+         # and `CrossAttention.use_flash_attention` is set to `True`.
+         try:
+             # You can install flash attention by cloning their Github repo,
+             # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)
+             # and then running `python setup.py install`
+             from flash_attn.flash_attention import FlashAttention
+
+             self.flash = FlashAttention()
+             # Set the scale for scaled dot-product attention.
+             self.flash.softmax_scale = self.scale
+         # Set to `None` if it's not installed
+         except ImportError:
+             self.flash = None
+
+     def forward(self, x, context=None, mask=None):
+         """
+         :param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
+         :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
+         """
+
+         # If `cond` is `None` we perform self attention
+         has_cond = context is not None
+         if not has_cond:
+             context = x
+
+         # Get query, key and value vectors
+         q = self.to_q(x)
+         k = self.to_k(context)
+         v = self.to_v(context)
+
+         # Use flash attention if it's available and the head size is less than or equal to `128`
+         if (
+             CrossAttention.use_flash_attention
+             and self.flash is not None
+             and not has_cond
+             and self.d_head <= 128
+         ):
+             return self.flash_attention(q, k, v)
+         # Otherwise, fallback to normal attention
+         else:
+             return self.normal_attention(q, k, v)
+
+     def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+         """
+         #### Flash Attention
+         :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+         :param k: are the key vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+         :param v: are the value vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+         """
+
+         # Get batch size and number of elements along sequence axis (`width * height`)
+         batch_size, seq_len, _ = q.shape
+
+         # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of
+         # shape `[batch_size, seq_len, 3, n_heads * d_head]`
+         qkv = torch.stack((q, k, v), dim=2)
+         # Split the heads
+         qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)
+
+         # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to
+         # fit this size.
+         if self.d_head <= 32:
+             pad = 32 - self.d_head
+         elif self.d_head <= 64:
+             pad = 64 - self.d_head
+         elif self.d_head <= 128:
+             pad = 128 - self.d_head
+         else:
+             raise ValueError(f"Head size ${self.d_head} too large for Flash Attention")
+
+         # Pad the heads
+         if pad:
+             qkv = torch.cat(
+                 (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1
+             )
+
+         # Compute attention
+         # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
+         # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]`
+         # TODO here I add the dtype changing
+         out, _ = self.flash(qkv.type(torch.float16))
+         # Truncate the extra head size
+         out = out[:, :, :, : self.d_head].float()
+         # Reshape to `[batch_size, seq_len, n_heads * d_head]`
+         out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)
+
+         # Map to `[batch_size, height * width, d_model]` with a linear layer
+         return self.to_out(out)
+
+     def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+         """
+         #### Normal Attention
+
+         :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+         :param k: are the key vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+         :param v: are the value vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+         """
+
+         # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
+         q = q.view(*q.shape[:2], self.n_heads, -1)  # [bs, 64, 20, 32]
+         k = k.view(*k.shape[:2], self.n_heads, -1)  # [bs, 1, 20, 32]
+         v = v.view(*v.shape[:2], self.n_heads, -1)
+
+         # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
+         attn = torch.einsum("bihd,bjhd->bhij", q, k) * self.scale
+
+         # Compute softmax
+         # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
+         if self.is_inplace:
+             half = attn.shape[0] // 2
+             attn[half:] = attn[half:].softmax(dim=-1)
+             attn[:half] = attn[:half].softmax(dim=-1)
+         else:
+             attn = attn.softmax(dim=-1)
+
+         # Compute attention output
+         # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
+         # attn: [bs, 20, 64, 1]
+         # v: [bs, 1, 20, 32]
+         out = torch.einsum("bhij,bjhd->bihd", attn, v)
+         # Reshape to `[batch_size, height * width, n_heads * d_head]`
+         out = out.reshape(*out.shape[:2], -1)
+         # Map to `[batch_size, height * width, d_model]` with a linear layer
+         return self.to_out(out)
+
+
+ # class CrossAttention(nn.Module):
+ #     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+ #         super().__init__()
+ #         inner_dim = dim_head * heads
+ #         context_dim = default(context_dim, query_dim)
+
+ #         self.scale = dim_head ** -0.5
+ #         self.heads = heads
+
+ #         self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ #         self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ #         self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ #         self.to_out = nn.Sequential(
+ #             nn.Linear(inner_dim, query_dim),
+ #             nn.Dropout(dropout)
+ #         )
+
+ #     def forward(self, x, context=None, mask=None):
+ #         h = self.heads
+
+ #         q = self.to_q(x)
+ #         context = default(context, x)
+ #         k = self.to_k(context)
+ #         v = self.to_v(context)
+
+ #         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ #         sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+ #         if exists(mask):
+ #             mask = rearrange(mask, 'b ... -> b (...)')
+ #             max_neg_value = -torch.finfo(sim.dtype).max
+ #             mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ #             sim.masked_fill_(~mask, max_neg_value)
+
+ #         # attention, what we cannot get enough of
+ #         attn = sim.softmax(dim=-1)
+
+ #         out = einsum('b i j, b j d -> b i d', attn, v)
+ #         out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+ #         return self.to_out(out)
+
+
+ class BasicTransformerBlock(nn.Module):
+     def __init__(
+         self,
+         dim,
+         n_heads,
+         d_head,
+         dropout=0.0,
+         context_dim=None,
+         gated_ff=True,
+         checkpoint=True,
+     ):
+         super().__init__()
+         self.attn1 = CrossAttention(
+             query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
+         )  # is a self-attention
+         self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+         self.attn2 = CrossAttention(
+             query_dim=dim,
+             context_dim=context_dim,
+             heads=n_heads,
+             dim_head=d_head,
+             dropout=dropout,
+         )  # is self-attn if context is none
+         self.norm1 = nn.LayerNorm(dim)
+         self.norm2 = nn.LayerNorm(dim)
+         self.norm3 = nn.LayerNorm(dim)
+         self.checkpoint = checkpoint
+
+     def forward(self, x, context=None):
+         if context is None:
+             return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint)
+         else:
+             return checkpoint(
+                 self._forward, (x, context), self.parameters(), self.checkpoint
+             )
+
+     def _forward(self, x, context=None):
+         x = self.attn1(self.norm1(x)) + x
+         x = self.attn2(self.norm2(x), context=context) + x
+         x = self.ff(self.norm3(x)) + x
+         return x
+
+
+ class SpatialTransformer(nn.Module):
+     """
+     Transformer block for image-like data.
+     First, project the input (aka embedding)
+     and reshape to b, t, d.
+     Then apply standard transformer action.
+     Finally, reshape to image
+     """
+
+     def __init__(
+         self,
+         in_channels,
+         n_heads,
+         d_head,
+         depth=1,
+         dropout=0.0,
+         context_dim=None,
+         no_context=False,
+     ):
+         super().__init__()
+
+         if no_context:
+             context_dim = None
+
+         self.in_channels = in_channels
+         inner_dim = n_heads * d_head
+         self.norm = Normalize(in_channels)
+
+         self.proj_in = nn.Conv2d(
+             in_channels, inner_dim, kernel_size=1, stride=1, padding=0
+         )
+
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 BasicTransformerBlock(
+                     inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim
+                 )
+                 for d in range(depth)
+             ]
+         )
+
+         self.proj_out = zero_module(
+             nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+         )
+
+     def forward(self, x, context=None):
+         # note: if no context is given, cross-attention defaults to self-attention
+         b, c, h, w = x.shape
+         x_in = x
+         x = self.norm(x)
+         x = self.proj_in(x)
+         x = rearrange(x, "b c h w -> b (h w) c")
+         for block in self.transformer_blocks:
+             x = block(x, context=context)
+         x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+         x = self.proj_out(x)
+         return x + x_in
audioldm/latent_diffusion/util.py ADDED
@@ -0,0 +1,293 @@
+ # adopted from
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+ # and
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+ # and
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+ #
+ # thanks!
+
+ import math
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from einops import repeat
+
+ from audioldm.utils import instantiate_from_config
+
+
+ def make_beta_schedule(
+     schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
+ ):
+     if schedule == "linear":
+         betas = (
+             torch.linspace(
+                 linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
+             )
+             ** 2
+         )
+
+     elif schedule == "cosine":
+         timesteps = (
+             torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+         )
+         alphas = timesteps / (1 + cosine_s) * np.pi / 2
+         alphas = torch.cos(alphas).pow(2)
+         alphas = alphas / alphas[0]
+         betas = 1 - alphas[1:] / alphas[:-1]
+         betas = np.clip(betas, a_min=0, a_max=0.999)
+
+     elif schedule == "sqrt_linear":
+         betas = torch.linspace(
+             linear_start, linear_end, n_timestep, dtype=torch.float64
+         )
+     elif schedule == "sqrt":
+         betas = (
+             torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+             ** 0.5
+         )
+     else:
+         raise ValueError(f"schedule '{schedule}' unknown.")
+     return betas.numpy()
+
+
+ def make_ddim_timesteps(
+     ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
+ ):
+     if ddim_discr_method == "uniform":
+         c = num_ddpm_timesteps // num_ddim_timesteps
+         ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
+     elif ddim_discr_method == "quad":
+         ddim_timesteps = (
+             (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
+         ).astype(int)
+     else:
+         raise NotImplementedError(
+             f'There is no ddim discretization method called "{ddim_discr_method}"'
+         )
+
+     # assert ddim_timesteps.shape[0] == num_ddim_timesteps
+     # add one to get the final alpha values right (the ones from first scale to data during sampling)
+     steps_out = ddim_timesteps + 1
+     if verbose:
+         print(f"Selected timesteps for ddim sampler: {steps_out}")
+     return steps_out
+
+
+ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
+     # select alphas for computing the variance schedule
+     alphas = alphacums[ddim_timesteps]
+     alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
+
+     # according to the formula provided in https://arxiv.org/abs/2010.02502
+     sigmas = eta * np.sqrt(
+         (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
+     )
+     if verbose:
+         print(
+             f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
+         )
+         print(
+             f"For the chosen value of eta, which is {eta}, "
+             f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
+         )
+     return sigmas, alphas, alphas_prev
+
+
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+     """
+     Create a beta schedule that discretizes the given alpha_t_bar function,
+     which defines the cumulative product of (1-beta) over time from t = [0,1].
+     :param num_diffusion_timesteps: the number of betas to produce.
+     :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+                       produces the cumulative product of (1-beta) up to that
+                       part of the diffusion process.
+     :param max_beta: the maximum beta to use; use values lower than 1 to
+                      prevent singularities.
+     """
+     betas = []
+     for i in range(num_diffusion_timesteps):
+         t1 = i / num_diffusion_timesteps
+         t2 = (i + 1) / num_diffusion_timesteps
+         betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+     return np.array(betas)
+
+
+ def extract_into_tensor(a, t, x_shape):
+     b, *_ = t.shape
+     out = a.gather(-1, t).contiguous()
+     return out.reshape(b, *((1,) * (len(x_shape) - 1))).contiguous()
+
+
+ def checkpoint(func, inputs, params, flag):
+     """
+     Evaluate a function without caching intermediate activations, allowing for
+     reduced memory at the expense of extra compute in the backward pass.
+     :param func: the function to evaluate.
+     :param inputs: the argument sequence to pass to `func`.
+     :param params: a sequence of parameters `func` depends on but does not
+                    explicitly take as arguments.
+     :param flag: if False, disable gradient checkpointing.
+     """
+     if flag:
+         args = tuple(inputs) + tuple(params)
+         return CheckpointFunction.apply(func, len(inputs), *args)
+     else:
+         return func(*inputs)
+
+
+ class CheckpointFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, run_function, length, *args):
+         ctx.run_function = run_function
+         ctx.input_tensors = list(args[:length])
+         ctx.input_params = list(args[length:])
+
+         with torch.no_grad():
+             output_tensors = ctx.run_function(*ctx.input_tensors)
+         return output_tensors
+
+     @staticmethod
+     def backward(ctx, *output_grads):
+         ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+         with torch.enable_grad():
+             # Fixes a bug where the first op in run_function modifies the
+             # Tensor storage in place, which is not allowed for detach()'d
+             # Tensors.
+             shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+             output_tensors = ctx.run_function(*shallow_copies)
+         input_grads = torch.autograd.grad(
+             output_tensors,
+             ctx.input_tensors + ctx.input_params,
+             output_grads,
+             allow_unused=True,
+         )
+         del ctx.input_tensors
+         del ctx.input_params
+         del output_tensors
+         return (None, None) + input_grads
+
+
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+     """
+     Create sinusoidal timestep embeddings.
+     :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                       These may be fractional.
+     :param dim: the dimension of the output.
+     :param max_period: controls the minimum frequency of the embeddings.
+     :return: an [N x dim] Tensor of positional embeddings.
+     """
+     if not repeat_only:
+         half = dim // 2
+         freqs = torch.exp(
+             -math.log(max_period)
+             * torch.arange(start=0, end=half, dtype=torch.float32)
+             / half
+         ).to(device=timesteps.device)
+         args = timesteps[:, None].float() * freqs[None]
+         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+         if dim % 2:
+             embedding = torch.cat(
+                 [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
+             )
+     else:
+         embedding = repeat(timesteps, "b -> b d", d=dim)
+     return embedding
+
+
+ def zero_module(module):
+     """
+     Zero out the parameters of a module and return it.
+     """
+     for p in module.parameters():
+         p.detach().zero_()
+     return module
+
+
+ def scale_module(module, scale):
+     """
+     Scale the parameters of a module and return it.
+     """
+     for p in module.parameters():
+         p.detach().mul_(scale)
+     return module
+
+
+ def mean_flat(tensor):
+     """
+     Take the mean over all non-batch dimensions.
+     """
+     return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+ def normalization(channels):
+     """
+     Make a standard normalization layer.
+     :param channels: number of input channels.
+     :return: an nn.Module for normalization.
+     """
+     return GroupNorm32(32, channels)
+
+
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+ class SiLU(nn.Module):
+     def forward(self, x):
+         return x * torch.sigmoid(x)
+
+
+ class GroupNorm32(nn.GroupNorm):
+     def forward(self, x):
+         return super().forward(x.float()).type(x.dtype)
+
+
+ def conv_nd(dims, *args, **kwargs):
+     """
+     Create a 1D, 2D, or 3D convolution module.
+     """
+     if dims == 1:
+         return nn.Conv1d(*args, **kwargs)
+     elif dims == 2:
+         return nn.Conv2d(*args, **kwargs)
+     elif dims == 3:
+         return nn.Conv3d(*args, **kwargs)
+     raise ValueError(f"unsupported dimensions: {dims}")
+
+
+ def linear(*args, **kwargs):
+     """
+     Create a linear module.
+     """
+     return nn.Linear(*args, **kwargs)
+
+
+ def avg_pool_nd(dims, *args, **kwargs):
+     """
+     Create a 1D, 2D, or 3D average pooling module.
+     """
+     if dims == 1:
+         return nn.AvgPool1d(*args, **kwargs)
+     elif dims == 2:
+         return nn.AvgPool2d(*args, **kwargs)
+     elif dims == 3:
+         return nn.AvgPool3d(*args, **kwargs)
+     raise ValueError(f"unsupported dimensions: {dims}")
+
+
+ class HybridConditioner(nn.Module):
+     def __init__(self, c_concat_config, c_crossattn_config):
+         super().__init__()
+         self.concat_conditioner = instantiate_from_config(c_concat_config)
+         self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+
+     def forward(self, c_concat, c_crossattn):
+         c_concat = self.concat_conditioner(c_concat)
+         c_crossattn = self.crossattn_conditioner(c_crossattn)
+         return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}
+
+
+ def noise_like(shape, device, repeat=False):
+     repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
+         shape[0], *((1,) * (len(shape) - 1))
+     )
+     noise = lambda: torch.randn(shape, device=device)
+     return repeat_noise() if repeat else noise()
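As a quick orientation, a hedged usage sketch of two of the helpers above; the number of timesteps, batch size, and embedding width are arbitrary example values.

```python
# Hedged sketch: build a linear beta schedule and sinusoidal timestep embeddings.
import torch
from audioldm.latent_diffusion.util import make_beta_schedule, timestep_embedding

betas = make_beta_schedule("linear", n_timestep=1000)  # numpy array of 1000 betas
t = torch.randint(0, 1000, (4,))                       # one timestep index per batch element
emb = timestep_embedding(t, dim=128)                   # tensor of shape [4, 128]
```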
audioldm/stft.py ADDED
@@ -0,0 +1,257 @@
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+ from scipy.signal import get_window
+ from librosa.util import pad_center, tiny, normalize
+ from librosa.filters import mel as librosa_mel_fn
+
+
+ def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
+     """
+     Parameters
+     ----------
+     C: compression factor
+     """
+     return normalize_fun(torch.clamp(x, min=clip_val) * C)
+
+
+ def dynamic_range_decompression(x, C=1):
+     """
+     Parameters
+     ----------
+     C: compression factor used to compress
+     """
+     return torch.exp(x) / C
+
+
+ def window_sumsquare(
+     window,
+     n_frames,
+     hop_length,
+     win_length,
+     n_fft,
+     dtype=np.float32,
+     norm=None,
+ ):
+     """
+     # from librosa 0.6
+     Compute the sum-square envelope of a window function at a given hop length.
+
+     This is used to estimate modulation effects induced by windowing
+     observations in short-time Fourier transforms.
+
+     Parameters
+     ----------
+     window : string, tuple, number, callable, or list-like
+         Window specification, as in `get_window`
+
+     n_frames : int > 0
+         The number of analysis frames
+
+     hop_length : int > 0
+         The number of samples to advance between frames
+
+     win_length : [optional]
+         The length of the window function. By default, this matches `n_fft`.
+
+     n_fft : int > 0
+         The length of each analysis frame.
+
+     dtype : np.dtype
+         The data type of the output
+
+     Returns
+     -------
+     wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
+         The sum-squared envelope of the window function
+     """
+     if win_length is None:
+         win_length = n_fft
+
+     n = n_fft + hop_length * (n_frames - 1)
+     x = np.zeros(n, dtype=dtype)
+
+     # Compute the squared window at the desired length
+     win_sq = get_window(window, win_length, fftbins=True)
+     win_sq = normalize(win_sq, norm=norm) ** 2
+     win_sq = pad_center(win_sq, size=n_fft)
+
+     # Fill the envelope
+     for i in range(n_frames):
+         sample = i * hop_length
+         x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
+     return x
+
+
+ class STFT(torch.nn.Module):
+     """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
+
+     def __init__(self, filter_length, hop_length, win_length, window="hann"):
+         super(STFT, self).__init__()
+         self.filter_length = filter_length
+         self.hop_length = hop_length
+         self.win_length = win_length
+         self.window = window
+         self.forward_transform = None
+         scale = self.filter_length / self.hop_length
+         fourier_basis = np.fft.fft(np.eye(self.filter_length))
+
+         cutoff = int((self.filter_length / 2 + 1))
+         fourier_basis = np.vstack(
+             [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
+         )
+
+         forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
+         inverse_basis = torch.FloatTensor(
+             np.linalg.pinv(scale * fourier_basis).T[:, None, :]
+         )
+
+         if window is not None:
+             assert filter_length >= win_length
+             # get window and zero center pad it to filter_length
+             fft_window = get_window(window, win_length, fftbins=True)
+             fft_window = pad_center(fft_window, size=filter_length)
+             fft_window = torch.from_numpy(fft_window).float()
+
+             # window the bases
+             forward_basis *= fft_window
+             inverse_basis *= fft_window
+
+         self.register_buffer("forward_basis", forward_basis.float())
+         self.register_buffer("inverse_basis", inverse_basis.float())
+
+     def transform(self, input_data):
+         device = self.forward_basis.device
+         input_data = input_data.to(device)
+
+         num_batches = input_data.size(0)
+         num_samples = input_data.size(1)
+
+         self.num_samples = num_samples
+
+         # similar to librosa, reflect-pad the input
+         input_data = input_data.view(num_batches, 1, num_samples)
+         input_data = F.pad(
+             input_data.unsqueeze(1),
+             (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
+             mode="reflect",
+         )
+         input_data = input_data.squeeze(1)
+
+         forward_transform = F.conv1d(
+             input_data,
+             torch.autograd.Variable(self.forward_basis, requires_grad=False),
+             stride=self.hop_length,
+             padding=0,
+         )
+
+         cutoff = int((self.filter_length / 2) + 1)
+         real_part = forward_transform[:, :cutoff, :]
+         imag_part = forward_transform[:, cutoff:, :]
+
+         magnitude = torch.sqrt(real_part**2 + imag_part**2)
+         phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
+
+         return magnitude, phase
+
+     def inverse(self, magnitude, phase):
+         device = self.forward_basis.device
+         magnitude, phase = magnitude.to(device), phase.to(device)
+
+         recombine_magnitude_phase = torch.cat(
+             [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
+         )
+
+         inverse_transform = F.conv_transpose1d(
+             recombine_magnitude_phase,
+             torch.autograd.Variable(self.inverse_basis, requires_grad=False),
+             stride=self.hop_length,
+             padding=0,
+         )
+
+         if self.window is not None:
+             window_sum = window_sumsquare(
+                 self.window,
+                 magnitude.size(-1),
+                 hop_length=self.hop_length,
+                 win_length=self.win_length,
+                 n_fft=self.filter_length,
+                 dtype=np.float32,
+             )
+             # remove modulation effects
+             approx_nonzero_indices = torch.from_numpy(
+                 np.where(window_sum > tiny(window_sum))[0]
+             )
+             window_sum = torch.autograd.Variable(
+                 torch.from_numpy(window_sum), requires_grad=False
+             )
+             inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
+                 approx_nonzero_indices
+             ]
+
+             # scale by hop ratio
+             inverse_transform *= float(self.filter_length) / self.hop_length
+
+         inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
+         inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
+
+         return inverse_transform
+
+     def forward(self, input_data):
+         self.magnitude, self.phase = self.transform(input_data)
+         reconstruction = self.inverse(self.magnitude, self.phase)
+         return reconstruction
+
+
+ class TacotronSTFT(torch.nn.Module):
+     def __init__(
+         self,
+         filter_length,
+         hop_length,
+         win_length,
+         n_mel_channels,
+         sampling_rate,
+         mel_fmin,
+         mel_fmax,
+     ):
+         super(TacotronSTFT, self).__init__()
+         self.n_mel_channels = n_mel_channels
+         self.sampling_rate = sampling_rate
+         self.stft_fn = STFT(filter_length, hop_length, win_length)
+         mel_basis = librosa_mel_fn(
+             sr=sampling_rate, n_fft=filter_length, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax
+         )
+         mel_basis = torch.from_numpy(mel_basis).float()
+         self.register_buffer("mel_basis", mel_basis)
+
+     def spectral_normalize(self, magnitudes, normalize_fun):
+         output = dynamic_range_compression(magnitudes, normalize_fun)
+         return output
+
+     def spectral_de_normalize(self, magnitudes):
+         output = dynamic_range_decompression(magnitudes)
+         return output
+
+     def mel_spectrogram(self, y, normalize_fun=torch.log):
+         """Computes mel-spectrograms from a batch of waves
+         PARAMS
+         ------
+         y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
+
+         RETURNS
+         -------
+         mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
+         """
+         assert torch.min(y.data) >= -1, torch.min(y.data)
+         assert torch.max(y.data) <= 1, torch.max(y.data)
+
+         magnitudes, phases = self.stft_fn.transform(y)
+         magnitudes = magnitudes.data
+         mel_output = torch.matmul(self.mel_basis, magnitudes)
+         mel_output = self.spectral_normalize(mel_output, normalize_fun)
+         energy = torch.norm(magnitudes, dim=1)
+
+         log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun)
+
+         return mel_output, log_magnitudes, energy
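A hedged usage sketch of `TacotronSTFT` is below; the STFT and mel settings mirror the preprocessing config that appears elsewhere in this commit (1024-point STFT, hop 160, 64 mel bins, 16 kHz, fmin 0, fmax 8000), and the input waveform is a placeholder.

```python
# Hedged sketch: extract a 64-bin mel spectrogram from a 16 kHz waveform batch.
import torch
from audioldm.stft import TacotronSTFT

fn_STFT = TacotronSTFT(
    filter_length=1024, hop_length=160, win_length=1024,
    n_mel_channels=64, sampling_rate=16000, mel_fmin=0, mel_fmax=8000,
)
waveform = torch.rand(1, 16000 * 10) * 2 - 1   # placeholder 10-second clip in [-1, 1]
mel, log_magnitudes, energy = fn_STFT.mel_spectrogram(waveform)  # mel: [1, 64, frames]
```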
audioldm/utils.py ADDED
@@ -0,0 +1,177 @@
+ import os
+ import importlib
+
+
+ CACHE_DIR = os.getenv(
+     "AUDIOLDM_CACHE_DIR",
+     os.path.join(os.path.expanduser("~"), ".cache/audioldm"))
+
+
+ def default_audioldm_config(model_name="audioldm-s-full"):
+     basic_config = {
+         "wave_file_save_path": "./output",
+         "id": {
+             "version": "v1",
+             "name": "default",
+             "root": "/mnt/fast/nobackup/users/hl01486/projects/general_audio_generation/AudioLDM-python/config/default/latent_diffusion.yaml",
+         },
+         "preprocessing": {
+             "audio": {"sampling_rate": 16000, "max_wav_value": 32768},
+             "stft": {"filter_length": 1024, "hop_length": 160, "win_length": 1024},
+             "mel": {
+                 "n_mel_channels": 64,
+                 "mel_fmin": 0,
+                 "mel_fmax": 8000,
+                 "freqm": 0,
+                 "timem": 0,
+                 "blur": False,
+                 "mean": -4.63,
+                 "std": 2.74,
+                 "target_length": 1024,
+             },
+         },
+         "model": {
+             "device": "cuda",
+             "target": "audioldm.pipline.LatentDiffusion",
+             "params": {
+                 "base_learning_rate": 5e-06,
+                 "linear_start": 0.0015,
+                 "linear_end": 0.0195,
+                 "num_timesteps_cond": 1,
+                 "log_every_t": 200,
+                 "timesteps": 1000,
+                 "first_stage_key": "fbank",
+                 "cond_stage_key": "waveform",
+                 "latent_t_size": 256,
+                 "latent_f_size": 16,
+                 "channels": 8,
+                 "cond_stage_trainable": True,
+                 "conditioning_key": "film",
+                 "monitor": "val/loss_simple_ema",
+                 "scale_by_std": True,
+                 "unet_config": {
+                     "target": "audioldm.latent_diffusion.openaimodel.UNetModel",
+                     "params": {
+                         "image_size": 64,
+                         "extra_film_condition_dim": 512,
+                         "extra_film_use_concat": True,
+                         "in_channels": 8,
+                         "out_channels": 8,
+                         "model_channels": 128,
+                         "attention_resolutions": [8, 4, 2],
+                         "num_res_blocks": 2,
+                         "channel_mult": [1, 2, 3, 5],
+                         "num_head_channels": 32,
+                         "use_spatial_transformer": True,
+                     },
+                 },
+                 "first_stage_config": {
+                     "base_learning_rate": 4.5e-05,
+                     "target": "audioldm.variational_autoencoder.autoencoder.AutoencoderKL",
+                     "params": {
+                         "monitor": "val/rec_loss",
+                         "image_key": "fbank",
+                         "subband": 1,
+                         "embed_dim": 8,
+                         "time_shuffle": 1,
+                         "ddconfig": {
+                             "double_z": True,
+                             "z_channels": 8,
+                             "resolution": 256,
+                             "downsample_time": False,
+                             "in_channels": 1,
+                             "out_ch": 1,
+                             "ch": 128,
+                             "ch_mult": [1, 2, 4],
+                             "num_res_blocks": 2,
+                             "attn_resolutions": [],
+                             "dropout": 0.0,
+                         },
+                     },
+                 },
+                 "cond_stage_config": {
+                     "target": "audioldm.clap.encoders.CLAPAudioEmbeddingClassifierFreev2",
+                     "params": {
+                         "key": "waveform",
+                         "sampling_rate": 16000,
+                         "embed_mode": "audio",
+                         "unconditional_prob": 0.1,
+                     },
+                 },
+             },
+         },
+     }
+
+     if("-l-" in model_name):
+         basic_config["model"]["params"]["unet_config"]["params"]["model_channels"] = 256
+         basic_config["model"]["params"]["unet_config"]["params"]["num_head_channels"] = 64
+     elif("-m-" in model_name):
+         basic_config["model"]["params"]["unet_config"]["params"]["model_channels"] = 192
+         basic_config["model"]["params"]["cond_stage_config"]["params"]["amodel"] = "HTSAT-base"  # This model uses a larger HTSAT
+
+     return basic_config
+
+
+ def get_metadata():
+     return {
+         "audioldm-s-full": {
+             "path": os.path.join(
+                 CACHE_DIR,
+                 "audioldm-s-full.ckpt",
+             ),
+             "url": "https://zenodo.org/record/7600541/files/audioldm-s-full?download=1",
+         },
+         "audioldm-l-full": {
+             "path": os.path.join(
+                 CACHE_DIR,
+                 "audioldm-l-full.ckpt",
+             ),
+             "url": "https://zenodo.org/record/7698295/files/audioldm-full-l.ckpt?download=1",
+         },
+         "audioldm-s-full-v2": {
+             "path": os.path.join(
+                 CACHE_DIR,
+                 "audioldm-s-full-v2.ckpt",
+             ),
+             "url": "https://zenodo.org/record/7698295/files/audioldm-full-s-v2.ckpt?download=1",
+         },
+         "audioldm-m-text-ft": {
+             "path": os.path.join(
+                 CACHE_DIR,
+                 "audioldm-m-text-ft.ckpt",
+             ),
+             "url": "https://zenodo.org/record/7813012/files/audioldm-m-text-ft.ckpt?download=1",
+         },
+         "audioldm-s-text-ft": {
+             "path": os.path.join(
+                 CACHE_DIR,
+                 "audioldm-s-text-ft.ckpt",
+             ),
+             "url": "https://zenodo.org/record/7813012/files/audioldm-s-text-ft.ckpt?download=1",
+         },
+         "audioldm-m-full": {
+             "path": os.path.join(
+                 CACHE_DIR,
+                 "audioldm-m-full.ckpt",
+             ),
+             "url": "https://zenodo.org/record/7813012/files/audioldm-m-full.ckpt?download=1",
+         },
+     }
+
+
+ def get_obj_from_str(string, reload=False):
+     module, cls = string.rsplit(".", 1)
+     if reload:
+         module_imp = importlib.import_module(module)
+         importlib.reload(module_imp)
+     return getattr(importlib.import_module(module, package=None), cls)
+
+
+ def instantiate_from_config(config):
+     if not "target" in config:
+         if config == "__is_first_stage__":
+             return None
+         elif config == "__is_unconditional__":
+             return None
+         raise KeyError("Expected key `target` to instantiate.")
+     return get_obj_from_str(config["target"])(**config.get("params", dict()))
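For orientation, a hedged sketch of how `instantiate_from_config` resolves a `{"target", "params"}` dict; here the first-stage VAE entry from the default config is constructed. This only builds the module (no checkpoint weights are loaded) and requires the target module to be importable.

```python
# Hedged sketch: construct the AutoencoderKL named in the default config.
from audioldm.utils import default_audioldm_config, instantiate_from_config

config = default_audioldm_config("audioldm-s-full")
vae_config = config["model"]["params"]["first_stage_config"]
vae = instantiate_from_config(vae_config)  # equivalent to AutoencoderKL(**vae_config["params"])
```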
audioldm/variational_autoencoder/__init__.py ADDED
@@ -0,0 +1 @@
+ from .autoencoder import AutoencoderKL
audioldm/variational_autoencoder/autoencoder.py ADDED
@@ -0,0 +1,131 @@
+ import torch
+ from torch import nn
+
+ from audioldm.variational_autoencoder.modules import Encoder, Decoder
+ from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution
+ from audioldm.hifigan.utilities import get_vocoder, vocoder_infer
+
+
+ class AutoencoderKL(nn.Module):
+     def __init__(
+         self,
+         ddconfig=None,
+         lossconfig=None,
+         image_key="fbank",
+         embed_dim=None,
+         time_shuffle=1,
+         subband=1,
+         ckpt_path=None,
+         reload_from_ckpt=None,
+         ignore_keys=[],
+         colorize_nlabels=None,
+         monitor=None,
+         base_learning_rate=1e-5,
+         scale_factor=1
+     ):
+         super().__init__()
+
+         self.encoder = Encoder(**ddconfig)
+         self.decoder = Decoder(**ddconfig)
+         self.ema_decoder = None
+         self.image_key = image_key  # referenced by freq_split_subband / freq_merge_subband below
+         self.flag_first_run = True  # referenced by forward() below
+
+         self.subband = int(subband)
+         if self.subband > 1:
+             print("Use subband decomposition %s" % self.subband)
+
+         self.quant_conv = nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
+         self.post_quant_conv = nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+         self.ema_post_quant_conv = None
+
+         self.vocoder = get_vocoder(None, "cpu")
+         self.embed_dim = embed_dim
+
+         if monitor is not None:
+             self.monitor = monitor
+
+         self.time_shuffle = time_shuffle
+         self.reload_from_ckpt = reload_from_ckpt
+         self.reloaded = False
+         self.mean, self.std = None, None
+
+         self.scale_factor = scale_factor
+
+     @property
+     def device(self):
+         return next(self.parameters()).device
+
+     def freq_split_subband(self, fbank):
+         if self.subband == 1 or self.image_key != "stft":
+             return fbank
+
+         bs, ch, tstep, fbins = fbank.size()
+
+         assert fbank.size(-1) % self.subband == 0
+         assert ch == 1
+
+         return (
+             fbank.squeeze(1)
+             .reshape(bs, tstep, self.subband, fbins // self.subband)
+             .permute(0, 2, 1, 3)
+         )
+
+     def freq_merge_subband(self, subband_fbank):
+         if self.subband == 1 or self.image_key != "stft":
+             return subband_fbank
+         assert subband_fbank.size(1) == self.subband  # Channel dimension
+         bs, sub_ch, tstep, fbins = subband_fbank.size()
+         return subband_fbank.permute(0, 2, 1, 3).reshape(bs, tstep, -1).unsqueeze(1)
+
+     def encode(self, x):
+         x = self.freq_split_subband(x)
+         h = self.encoder(x)
+         moments = self.quant_conv(h)
+         posterior = DiagonalGaussianDistribution(moments)
+         return posterior
+
+     @torch.no_grad()
+     def encode_first_stage(self, x):
+         return self.encode(x)
+
+     def decode(self, z, use_ema=False):
+         if use_ema and (not hasattr(self, 'ema_decoder') or self.ema_decoder is None):
+             print("VAE does not have EMA modules, but specified use_ema. "
+                   "Using the non-EMA modules instead.")
+         if use_ema and hasattr(self, 'ema_decoder') and self.ema_decoder is not None:
+             z = self.ema_post_quant_conv(z)
+             dec = self.ema_decoder(z)
+         else:
+             z = self.post_quant_conv(z)
+             dec = self.decoder(z)
+         return self.freq_merge_subband(dec)
+
+     def decode_first_stage(self, z, allow_grad=False, use_ema=False):
+         with torch.set_grad_enabled(allow_grad):
+             z = z / self.scale_factor
+             return self.decode(z, use_ema)
+
+     def decode_to_waveform(self, dec, allow_grad=False):
+         dec = dec.squeeze(1).permute(0, 2, 1)
+         wav_reconstruction = vocoder_infer(dec, self.vocoder, allow_grad=allow_grad)
+         return wav_reconstruction
+
+     def forward(self, input, sample_posterior=True):
+         posterior = self.encode(input)
+         z = posterior.sample() if sample_posterior else posterior.mode()
+
+         if self.flag_first_run:
+             print("Latent size: ", z.size())
+             self.flag_first_run = False
+
+         return self.decode(z), posterior
+
+     def get_first_stage_encoding(self, encoder_posterior):
+         if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+             z = encoder_posterior.sample()
+         elif isinstance(encoder_posterior, torch.Tensor):
+             z = encoder_posterior
+         else:
+             raise NotImplementedError(
+                 f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
+             )
+         return self.scale_factor * z
audioldm/variational_autoencoder/distributions.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ import torch
+ import numpy as np
+
+
+ class AbstractDistribution:
+     def sample(self):
+         raise NotImplementedError()
+
+     def mode(self):
+         raise NotImplementedError()
+
+
+ class DiracDistribution(AbstractDistribution):
+     def __init__(self, value):
+         self.value = value
+
+     def sample(self):
+         return self.value
+
+     def mode(self):
+         return self.value
+
+
+ class DiagonalGaussianDistribution(object):
+     def __init__(self, parameters, deterministic=False):
+         self.parameters = parameters
+         self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+         self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+         self.deterministic = deterministic
+         self.std = torch.exp(0.5 * self.logvar)
+         self.var = torch.exp(self.logvar)
+         if self.deterministic:
+             self.var = self.std = torch.zeros_like(self.mean).to(
+                 device=self.parameters.device
+             )
+
+     def sample(self):
+         x = self.mean + self.std * torch.randn(self.mean.shape).to(
+             device=self.parameters.device
+         )
+         return x
+
+     def kl(self, other=None):
+         if self.deterministic:
+             return torch.Tensor([0.0])
+         else:
+             if other is None:
+                 return 0.5 * torch.mean(
+                     torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+                     dim=[1, 2, 3],
+                 )
+             else:
+                 return 0.5 * torch.mean(
+                     torch.pow(self.mean - other.mean, 2) / other.var
+                     + self.var / other.var
+                     - 1.0
+                     - self.logvar
+                     + other.logvar,
+                     dim=[1, 2, 3],
+                 )
+
+     def nll(self, sample, dims=[1, 2, 3]):
+         if self.deterministic:
+             return torch.Tensor([0.0])
+         logtwopi = np.log(2.0 * np.pi)
+         return 0.5 * torch.sum(
+             logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+             dim=dims,
+         )
+
+     def mode(self):
+         return self.mean
+
+
+ def normal_kl(mean1, logvar1, mean2, logvar2):
+     """
+     source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+     Compute the KL divergence between two gaussians.
+     Shapes are automatically broadcasted, so batches can be compared to
+     scalars, among other use cases.
+     """
+     tensor = None
+     for obj in (mean1, logvar1, mean2, logvar2):
+         if isinstance(obj, torch.Tensor):
+             tensor = obj
+             break
+     assert tensor is not None, "at least one argument must be a Tensor"
+
+     # Force variances to be Tensors. Broadcasting helps convert scalars to
+     # Tensors, but it does not work for torch.exp().
+     logvar1, logvar2 = [
+         x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+         for x in (logvar1, logvar2)
+     ]
+
+     return 0.5 * (
+         -1.0
+         + logvar2
+         - logvar1
+         + torch.exp(logvar1 - logvar2)
+         + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+     )
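For orientation, a small illustrative sketch of how `DiagonalGaussianDistribution` is consumed by the VAE above; the tensor shape is an assumption.

import torch

# `moments` packs mean and log-variance along dim=1, as produced by quant_conv.
moments = torch.randn(2, 16, 64, 4)
posterior = DiagonalGaussianDistribution(moments)

z = posterior.sample()    # stochastic latent, shape (2, 8, 64, 4)
z_map = posterior.mode()  # deterministic alternative (the mean)
kl = posterior.kl()       # KL to a standard normal, one value per batch item: shape (2,)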
audioldm/variational_autoencoder/modules.py ADDED
@@ -0,0 +1,1067 @@
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from einops import rearrange
7
+
8
+ from audioldm.utils import instantiate_from_config
9
+ from audioldm.latent_diffusion.attention import LinearAttention
10
+
11
+
12
+ def get_timestep_embedding(timesteps, embedding_dim):
13
+ """
14
+ Build sinusoidal timestep embeddings. This matches the implementation in Denoising
+ Diffusion Probabilistic Models (originally adapted from Fairseq / tensor2tensor), but
+ differs slightly from the description in Section 3.5 of "Attention Is All You Need".
19
+ """
20
+ assert len(timesteps.shape) == 1
21
+
22
+ half_dim = embedding_dim // 2
23
+ emb = math.log(10000) / (half_dim - 1)
24
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
25
+ emb = emb.to(device=timesteps.device)
26
+ emb = timesteps.float()[:, None] * emb[None, :]
27
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
28
+ if embedding_dim % 2 == 1: # zero pad
29
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
30
+ return emb
31
+
32
+
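A quick, illustrative sanity check of the timestep-embedding helper above:

import torch

t = torch.tensor([0, 10, 999])                      # a batch of diffusion timesteps
emb = get_timestep_embedding(t, embedding_dim=128)
print(emb.shape)                                    # torch.Size([3, 128]): sin half, cos half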
33
+ def nonlinearity(x):
34
+ # swish
35
+ return x * torch.sigmoid(x)
36
+
37
+
38
+ def Normalize(in_channels, num_groups=32):
39
+ return torch.nn.GroupNorm(
40
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
41
+ )
42
+
43
+
44
+ class Upsample(nn.Module):
45
+ def __init__(self, in_channels, with_conv):
46
+ super().__init__()
47
+ self.with_conv = with_conv
48
+ if self.with_conv:
49
+ self.conv = torch.nn.Conv2d(
50
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
51
+ )
52
+
53
+ def forward(self, x):
54
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
55
+ if self.with_conv:
56
+ x = self.conv(x)
57
+ return x
58
+
59
+
60
+ class UpsampleTimeStride4(nn.Module):
61
+ def __init__(self, in_channels, with_conv):
62
+ super().__init__()
63
+ self.with_conv = with_conv
64
+ if self.with_conv:
65
+ self.conv = torch.nn.Conv2d(
66
+ in_channels, in_channels, kernel_size=5, stride=1, padding=2
67
+ )
68
+
69
+ def forward(self, x):
70
+ x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest")
71
+ if self.with_conv:
72
+ x = self.conv(x)
73
+ return x
74
+
75
+
76
+ class Downsample(nn.Module):
77
+ def __init__(self, in_channels, with_conv):
78
+ super().__init__()
79
+ self.with_conv = with_conv
80
+ if self.with_conv:
81
+ # Do time downsampling here
82
+ # no asymmetric padding in torch conv, must do it ourselves
83
+ self.conv = torch.nn.Conv2d(
84
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
85
+ )
86
+
87
+ def forward(self, x):
88
+ if self.with_conv:
89
+ pad = (0, 1, 0, 1)
90
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
91
+ x = self.conv(x)
92
+ else:
93
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
94
+ return x
95
+
96
+
97
+ class DownsampleTimeStride4(nn.Module):
98
+ def __init__(self, in_channels, with_conv):
99
+ super().__init__()
100
+ self.with_conv = with_conv
101
+ if self.with_conv:
102
+ # Do time downsampling here
103
+ # no asymmetric padding in torch conv, must do it ourselves
104
+ self.conv = torch.nn.Conv2d(
105
+ in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1
106
+ )
107
+
108
+ def forward(self, x):
109
+ if self.with_conv:
110
+ pad = (0, 1, 0, 1)
111
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
112
+ x = self.conv(x)
113
+ else:
114
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2))
115
+ return x
116
+
117
+
118
+ class ResnetBlock(nn.Module):
119
+ def __init__(
120
+ self,
121
+ *,
122
+ in_channels,
123
+ out_channels=None,
124
+ conv_shortcut=False,
125
+ dropout,
126
+ temb_channels=512,
127
+ ):
128
+ super().__init__()
129
+ self.in_channels = in_channels
130
+ out_channels = in_channels if out_channels is None else out_channels
131
+ self.out_channels = out_channels
132
+ self.use_conv_shortcut = conv_shortcut
133
+
134
+ self.norm1 = Normalize(in_channels)
135
+ self.conv1 = torch.nn.Conv2d(
136
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
137
+ )
138
+ if temb_channels > 0:
139
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
140
+ self.norm2 = Normalize(out_channels)
141
+ self.dropout = torch.nn.Dropout(dropout)
142
+ self.conv2 = torch.nn.Conv2d(
143
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
144
+ )
145
+ if self.in_channels != self.out_channels:
146
+ if self.use_conv_shortcut:
147
+ self.conv_shortcut = torch.nn.Conv2d(
148
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
149
+ )
150
+ else:
151
+ self.nin_shortcut = torch.nn.Conv2d(
152
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
153
+ )
154
+
155
+ def forward(self, x, temb):
156
+ h = x
157
+ h = self.norm1(h)
158
+ h = nonlinearity(h)
159
+ h = self.conv1(h)
160
+
161
+ if temb is not None:
162
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
163
+
164
+ h = self.norm2(h)
165
+ h = nonlinearity(h)
166
+ h = self.dropout(h)
167
+ h = self.conv2(h)
168
+
169
+ if self.in_channels != self.out_channels:
170
+ if self.use_conv_shortcut:
171
+ x = self.conv_shortcut(x)
172
+ else:
173
+ x = self.nin_shortcut(x)
174
+
175
+ return x + h
176
+
177
+
178
+ class LinAttnBlock(LinearAttention):
179
+ """to match AttnBlock usage"""
180
+
181
+ def __init__(self, in_channels):
182
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
183
+
184
+
185
+ class AttnBlock(nn.Module):
186
+ def __init__(self, in_channels):
187
+ super().__init__()
188
+ self.in_channels = in_channels
189
+
190
+ self.norm = Normalize(in_channels)
191
+ self.q = torch.nn.Conv2d(
192
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
193
+ )
194
+ self.k = torch.nn.Conv2d(
195
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
196
+ )
197
+ self.v = torch.nn.Conv2d(
198
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
199
+ )
200
+ self.proj_out = torch.nn.Conv2d(
201
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
202
+ )
203
+
204
+ def forward(self, x):
205
+ h_ = x
206
+ h_ = self.norm(h_)
207
+ q = self.q(h_)
208
+ k = self.k(h_)
209
+ v = self.v(h_)
210
+
211
+ # compute attention
212
+ b, c, h, w = q.shape
213
+ q = q.reshape(b, c, h * w).contiguous()
214
+ q = q.permute(0, 2, 1).contiguous() # b,hw,c
215
+ k = k.reshape(b, c, h * w).contiguous() # b,c,hw
216
+ w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
217
+ w_ = w_ * (int(c) ** (-0.5))
218
+ w_ = torch.nn.functional.softmax(w_, dim=2)
219
+
220
+ # attend to values
221
+ v = v.reshape(b, c, h * w).contiguous()
222
+ w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q)
223
+ h_ = torch.bmm(
224
+ v, w_
225
+ ).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
226
+ h_ = h_.reshape(b, c, h, w).contiguous()
227
+
228
+ h_ = self.proj_out(h_)
229
+
230
+ return x + h_
231
+
232
+
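The block above is plain single-head self-attention over the flattened spatial grid, wrapped in a residual connection. A minimal, illustrative shape check (the channel count must be divisible by the GroupNorm's 32 groups):

import torch

attn = AttnBlock(in_channels=64)
x = torch.randn(1, 64, 16, 8)    # (batch, channels, height, width)
assert attn(x).shape == x.shape  # residual connection preserves the shape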
233
+ def make_attn(in_channels, attn_type="vanilla"):
234
+ assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
235
+ # print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
236
+ if attn_type == "vanilla":
237
+ return AttnBlock(in_channels)
238
+ elif attn_type == "none":
239
+ return nn.Identity(in_channels)
240
+ else:
241
+ return LinAttnBlock(in_channels)
242
+
243
+
244
+ class Model(nn.Module):
245
+ def __init__(
246
+ self,
247
+ *,
248
+ ch,
249
+ out_ch,
250
+ ch_mult=(1, 2, 4, 8),
251
+ num_res_blocks,
252
+ attn_resolutions,
253
+ dropout=0.0,
254
+ resamp_with_conv=True,
255
+ in_channels,
256
+ resolution,
257
+ use_timestep=True,
258
+ use_linear_attn=False,
259
+ attn_type="vanilla",
260
+ ):
261
+ super().__init__()
262
+ if use_linear_attn:
263
+ attn_type = "linear"
264
+ self.ch = ch
265
+ self.temb_ch = self.ch * 4
266
+ self.num_resolutions = len(ch_mult)
267
+ self.num_res_blocks = num_res_blocks
268
+ self.resolution = resolution
269
+ self.in_channels = in_channels
270
+
271
+ self.use_timestep = use_timestep
272
+ if self.use_timestep:
273
+ # timestep embedding
274
+ self.temb = nn.Module()
275
+ self.temb.dense = nn.ModuleList(
276
+ [
277
+ torch.nn.Linear(self.ch, self.temb_ch),
278
+ torch.nn.Linear(self.temb_ch, self.temb_ch),
279
+ ]
280
+ )
281
+
282
+ # downsampling
283
+ self.conv_in = torch.nn.Conv2d(
284
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
285
+ )
286
+
287
+ curr_res = resolution
288
+ in_ch_mult = (1,) + tuple(ch_mult)
289
+ self.down = nn.ModuleList()
290
+ for i_level in range(self.num_resolutions):
291
+ block = nn.ModuleList()
292
+ attn = nn.ModuleList()
293
+ block_in = ch * in_ch_mult[i_level]
294
+ block_out = ch * ch_mult[i_level]
295
+ for i_block in range(self.num_res_blocks):
296
+ block.append(
297
+ ResnetBlock(
298
+ in_channels=block_in,
299
+ out_channels=block_out,
300
+ temb_channels=self.temb_ch,
301
+ dropout=dropout,
302
+ )
303
+ )
304
+ block_in = block_out
305
+ if curr_res in attn_resolutions:
306
+ attn.append(make_attn(block_in, attn_type=attn_type))
307
+ down = nn.Module()
308
+ down.block = block
309
+ down.attn = attn
310
+ if i_level != self.num_resolutions - 1:
311
+ down.downsample = Downsample(block_in, resamp_with_conv)
312
+ curr_res = curr_res // 2
313
+ self.down.append(down)
314
+
315
+ # middle
316
+ self.mid = nn.Module()
317
+ self.mid.block_1 = ResnetBlock(
318
+ in_channels=block_in,
319
+ out_channels=block_in,
320
+ temb_channels=self.temb_ch,
321
+ dropout=dropout,
322
+ )
323
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
324
+ self.mid.block_2 = ResnetBlock(
325
+ in_channels=block_in,
326
+ out_channels=block_in,
327
+ temb_channels=self.temb_ch,
328
+ dropout=dropout,
329
+ )
330
+
331
+ # upsampling
332
+ self.up = nn.ModuleList()
333
+ for i_level in reversed(range(self.num_resolutions)):
334
+ block = nn.ModuleList()
335
+ attn = nn.ModuleList()
336
+ block_out = ch * ch_mult[i_level]
337
+ skip_in = ch * ch_mult[i_level]
338
+ for i_block in range(self.num_res_blocks + 1):
339
+ if i_block == self.num_res_blocks:
340
+ skip_in = ch * in_ch_mult[i_level]
341
+ block.append(
342
+ ResnetBlock(
343
+ in_channels=block_in + skip_in,
344
+ out_channels=block_out,
345
+ temb_channels=self.temb_ch,
346
+ dropout=dropout,
347
+ )
348
+ )
349
+ block_in = block_out
350
+ if curr_res in attn_resolutions:
351
+ attn.append(make_attn(block_in, attn_type=attn_type))
352
+ up = nn.Module()
353
+ up.block = block
354
+ up.attn = attn
355
+ if i_level != 0:
356
+ up.upsample = Upsample(block_in, resamp_with_conv)
357
+ curr_res = curr_res * 2
358
+ self.up.insert(0, up) # prepend to get consistent order
359
+
360
+ # end
361
+ self.norm_out = Normalize(block_in)
362
+ self.conv_out = torch.nn.Conv2d(
363
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
364
+ )
365
+
366
+ def forward(self, x, t=None, context=None):
367
+ # assert x.shape[2] == x.shape[3] == self.resolution
368
+ if context is not None:
369
+ # assume aligned context, cat along channel axis
370
+ x = torch.cat((x, context), dim=1)
371
+ if self.use_timestep:
372
+ # timestep embedding
373
+ assert t is not None
374
+ temb = get_timestep_embedding(t, self.ch)
375
+ temb = self.temb.dense[0](temb)
376
+ temb = nonlinearity(temb)
377
+ temb = self.temb.dense[1](temb)
378
+ else:
379
+ temb = None
380
+
381
+ # downsampling
382
+ hs = [self.conv_in(x)]
383
+ for i_level in range(self.num_resolutions):
384
+ for i_block in range(self.num_res_blocks):
385
+ h = self.down[i_level].block[i_block](hs[-1], temb)
386
+ if len(self.down[i_level].attn) > 0:
387
+ h = self.down[i_level].attn[i_block](h)
388
+ hs.append(h)
389
+ if i_level != self.num_resolutions - 1:
390
+ hs.append(self.down[i_level].downsample(hs[-1]))
391
+
392
+ # middle
393
+ h = hs[-1]
394
+ h = self.mid.block_1(h, temb)
395
+ h = self.mid.attn_1(h)
396
+ h = self.mid.block_2(h, temb)
397
+
398
+ # upsampling
399
+ for i_level in reversed(range(self.num_resolutions)):
400
+ for i_block in range(self.num_res_blocks + 1):
401
+ h = self.up[i_level].block[i_block](
402
+ torch.cat([h, hs.pop()], dim=1), temb
403
+ )
404
+ if len(self.up[i_level].attn) > 0:
405
+ h = self.up[i_level].attn[i_block](h)
406
+ if i_level != 0:
407
+ h = self.up[i_level].upsample(h)
408
+
409
+ # end
410
+ h = self.norm_out(h)
411
+ h = nonlinearity(h)
412
+ h = self.conv_out(h)
413
+ return h
414
+
415
+ def get_last_layer(self):
416
+ return self.conv_out.weight
417
+
418
+
419
+ class Encoder(nn.Module):
420
+ def __init__(
421
+ self,
422
+ *,
423
+ ch,
424
+ out_ch,
425
+ ch_mult=(1, 2, 4, 8),
426
+ num_res_blocks,
427
+ attn_resolutions,
428
+ dropout=0.0,
429
+ resamp_with_conv=True,
430
+ in_channels,
431
+ resolution,
432
+ z_channels,
433
+ double_z=True,
434
+ use_linear_attn=False,
435
+ attn_type="vanilla",
436
+ downsample_time_stride4_levels=[],
437
+ **ignore_kwargs,
438
+ ):
439
+ super().__init__()
440
+ if use_linear_attn:
441
+ attn_type = "linear"
442
+ self.ch = ch
443
+ self.temb_ch = 0
444
+ self.num_resolutions = len(ch_mult)
445
+ self.num_res_blocks = num_res_blocks
446
+ self.resolution = resolution
447
+ self.in_channels = in_channels
448
+ self.downsample_time_stride4_levels = downsample_time_stride4_levels
449
+
450
+ if len(self.downsample_time_stride4_levels) > 0:
451
+ assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
452
+ "The level to perform downsample 4 operation need to be smaller than "
453
+ "the total resolution number %s" % str(self.num_resolutions)
454
+ )
455
+
456
+ # downsampling
457
+ self.conv_in = torch.nn.Conv2d(
458
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
459
+ )
460
+
461
+ curr_res = resolution
462
+ in_ch_mult = (1,) + tuple(ch_mult)
463
+ self.in_ch_mult = in_ch_mult
464
+ self.down = nn.ModuleList()
465
+ for i_level in range(self.num_resolutions):
466
+ block = nn.ModuleList()
467
+ attn = nn.ModuleList()
468
+ block_in = ch * in_ch_mult[i_level]
469
+ block_out = ch * ch_mult[i_level]
470
+ for i_block in range(self.num_res_blocks):
471
+ block.append(
472
+ ResnetBlock(
473
+ in_channels=block_in,
474
+ out_channels=block_out,
475
+ temb_channels=self.temb_ch,
476
+ dropout=dropout,
477
+ )
478
+ )
479
+ block_in = block_out
480
+ if curr_res in attn_resolutions:
481
+ attn.append(make_attn(block_in, attn_type=attn_type))
482
+ down = nn.Module()
483
+ down.block = block
484
+ down.attn = attn
485
+ if i_level != self.num_resolutions - 1:
486
+ if i_level in self.downsample_time_stride4_levels:
487
+ down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv)
488
+ else:
489
+ down.downsample = Downsample(block_in, resamp_with_conv)
490
+ curr_res = curr_res // 2
491
+ self.down.append(down)
492
+
493
+ # middle
494
+ self.mid = nn.Module()
495
+ self.mid.block_1 = ResnetBlock(
496
+ in_channels=block_in,
497
+ out_channels=block_in,
498
+ temb_channels=self.temb_ch,
499
+ dropout=dropout,
500
+ )
501
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
502
+ self.mid.block_2 = ResnetBlock(
503
+ in_channels=block_in,
504
+ out_channels=block_in,
505
+ temb_channels=self.temb_ch,
506
+ dropout=dropout,
507
+ )
508
+
509
+ # end
510
+ self.norm_out = Normalize(block_in)
511
+ self.conv_out = torch.nn.Conv2d(
512
+ block_in,
513
+ 2 * z_channels if double_z else z_channels,
514
+ kernel_size=3,
515
+ stride=1,
516
+ padding=1,
517
+ )
518
+
519
+ def forward(self, x):
520
+ # timestep embedding
521
+ temb = None
522
+ # downsampling
523
+ hs = [self.conv_in(x)]
524
+ for i_level in range(self.num_resolutions):
525
+ for i_block in range(self.num_res_blocks):
526
+ h = self.down[i_level].block[i_block](hs[-1], temb)
527
+ if len(self.down[i_level].attn) > 0:
528
+ h = self.down[i_level].attn[i_block](h)
529
+ hs.append(h)
530
+ if i_level != self.num_resolutions - 1:
531
+ hs.append(self.down[i_level].downsample(hs[-1]))
532
+
533
+ # middle
534
+ h = hs[-1]
535
+ h = self.mid.block_1(h, temb)
536
+ h = self.mid.attn_1(h)
537
+ h = self.mid.block_2(h, temb)
538
+
539
+ # end
540
+ h = self.norm_out(h)
541
+ h = nonlinearity(h)
542
+ h = self.conv_out(h)
543
+ return h
544
+
545
+
546
+ class Decoder(nn.Module):
547
+ def __init__(
548
+ self,
549
+ *,
550
+ ch,
551
+ out_ch,
552
+ ch_mult=(1, 2, 4, 8),
553
+ num_res_blocks,
554
+ attn_resolutions,
555
+ dropout=0.0,
556
+ resamp_with_conv=True,
557
+ in_channels,
558
+ resolution,
559
+ z_channels,
560
+ give_pre_end=False,
561
+ tanh_out=False,
562
+ use_linear_attn=False,
563
+ downsample_time_stride4_levels=[],
564
+ attn_type="vanilla",
565
+ **ignorekwargs,
566
+ ):
567
+ super().__init__()
568
+ if use_linear_attn:
569
+ attn_type = "linear"
570
+ self.ch = ch
571
+ self.temb_ch = 0
572
+ self.num_resolutions = len(ch_mult)
573
+ self.num_res_blocks = num_res_blocks
574
+ self.resolution = resolution
575
+ self.in_channels = in_channels
576
+ self.give_pre_end = give_pre_end
577
+ self.tanh_out = tanh_out
578
+ self.downsample_time_stride4_levels = downsample_time_stride4_levels
579
+
580
+ if len(self.downsample_time_stride4_levels) > 0:
581
+ assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
582
+ "The level to perform downsample 4 operation need to be smaller than "
583
+ "the total resolution number %s" % str(self.num_resolutions)
584
+ )
585
+
586
+ # compute in_ch_mult, block_in and curr_res at lowest res
587
+ in_ch_mult = (1,) + tuple(ch_mult)
588
+ block_in = ch * ch_mult[self.num_resolutions - 1]
589
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
590
+ self.z_shape = (1, z_channels, curr_res, curr_res)
591
+ # print("Working with z of shape {} = {} dimensions.".format(
592
+ # self.z_shape, np.prod(self.z_shape)))
593
+
594
+ # z to block_in
595
+ self.conv_in = torch.nn.Conv2d(
596
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
597
+ )
598
+
599
+ # middle
600
+ self.mid = nn.Module()
601
+ self.mid.block_1 = ResnetBlock(
602
+ in_channels=block_in,
603
+ out_channels=block_in,
604
+ temb_channels=self.temb_ch,
605
+ dropout=dropout,
606
+ )
607
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
608
+ self.mid.block_2 = ResnetBlock(
609
+ in_channels=block_in,
610
+ out_channels=block_in,
611
+ temb_channels=self.temb_ch,
612
+ dropout=dropout,
613
+ )
614
+
615
+ # upsampling
616
+ self.up = nn.ModuleList()
617
+ for i_level in reversed(range(self.num_resolutions)):
618
+ block = nn.ModuleList()
619
+ attn = nn.ModuleList()
620
+ block_out = ch * ch_mult[i_level]
621
+ for i_block in range(self.num_res_blocks + 1):
622
+ block.append(
623
+ ResnetBlock(
624
+ in_channels=block_in,
625
+ out_channels=block_out,
626
+ temb_channels=self.temb_ch,
627
+ dropout=dropout,
628
+ )
629
+ )
630
+ block_in = block_out
631
+ if curr_res in attn_resolutions:
632
+ attn.append(make_attn(block_in, attn_type=attn_type))
633
+ up = nn.Module()
634
+ up.block = block
635
+ up.attn = attn
636
+ if i_level != 0:
637
+ if i_level - 1 in self.downsample_time_stride4_levels:
638
+ up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv)
639
+ else:
640
+ up.upsample = Upsample(block_in, resamp_with_conv)
641
+ curr_res = curr_res * 2
642
+ self.up.insert(0, up) # prepend to get consistent order
643
+
644
+ # end
645
+ self.norm_out = Normalize(block_in)
646
+ self.conv_out = torch.nn.Conv2d(
647
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
648
+ )
649
+
650
+ def forward(self, z):
651
+ # assert z.shape[1:] == self.z_shape[1:]
652
+ self.last_z_shape = z.shape
653
+
654
+ # timestep embedding
655
+ temb = None
656
+
657
+ # z to block_in
658
+ h = self.conv_in(z)
659
+
660
+ # middle
661
+ h = self.mid.block_1(h, temb)
662
+ h = self.mid.attn_1(h)
663
+ h = self.mid.block_2(h, temb)
664
+
665
+ # upsampling
666
+ for i_level in reversed(range(self.num_resolutions)):
667
+ for i_block in range(self.num_res_blocks + 1):
668
+ h = self.up[i_level].block[i_block](h.float(), temb)
669
+ if len(self.up[i_level].attn) > 0:
670
+ h = self.up[i_level].attn[i_block](h.float())
671
+ if i_level != 0:
672
+ h = self.up[i_level].upsample(h.float())
673
+
674
+ # end
675
+ if self.give_pre_end:
676
+ return h
677
+
678
+ h = self.norm_out(h)
679
+ h = nonlinearity(h)
680
+ h = self.conv_out(h)
681
+ if self.tanh_out:
682
+ h = torch.tanh(h)
683
+ return h
684
+
685
+
686
+ class SimpleDecoder(nn.Module):
687
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
688
+ super().__init__()
689
+ self.model = nn.ModuleList(
690
+ [
691
+ nn.Conv2d(in_channels, in_channels, 1),
692
+ ResnetBlock(
693
+ in_channels=in_channels,
694
+ out_channels=2 * in_channels,
695
+ temb_channels=0,
696
+ dropout=0.0,
697
+ ),
698
+ ResnetBlock(
699
+ in_channels=2 * in_channels,
700
+ out_channels=4 * in_channels,
701
+ temb_channels=0,
702
+ dropout=0.0,
703
+ ),
704
+ ResnetBlock(
705
+ in_channels=4 * in_channels,
706
+ out_channels=2 * in_channels,
707
+ temb_channels=0,
708
+ dropout=0.0,
709
+ ),
710
+ nn.Conv2d(2 * in_channels, in_channels, 1),
711
+ Upsample(in_channels, with_conv=True),
712
+ ]
713
+ )
714
+ # end
715
+ self.norm_out = Normalize(in_channels)
716
+ self.conv_out = torch.nn.Conv2d(
717
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
718
+ )
719
+
720
+ def forward(self, x):
721
+ for i, layer in enumerate(self.model):
722
+ if i in [1, 2, 3]:
723
+ x = layer(x, None)
724
+ else:
725
+ x = layer(x)
726
+
727
+ h = self.norm_out(x)
728
+ h = nonlinearity(h)
729
+ x = self.conv_out(h)
730
+ return x
731
+
732
+
733
+ class UpsampleDecoder(nn.Module):
734
+ def __init__(
735
+ self,
736
+ in_channels,
737
+ out_channels,
738
+ ch,
739
+ num_res_blocks,
740
+ resolution,
741
+ ch_mult=(2, 2),
742
+ dropout=0.0,
743
+ ):
744
+ super().__init__()
745
+ # upsampling
746
+ self.temb_ch = 0
747
+ self.num_resolutions = len(ch_mult)
748
+ self.num_res_blocks = num_res_blocks
749
+ block_in = in_channels
750
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
751
+ self.res_blocks = nn.ModuleList()
752
+ self.upsample_blocks = nn.ModuleList()
753
+ for i_level in range(self.num_resolutions):
754
+ res_block = []
755
+ block_out = ch * ch_mult[i_level]
756
+ for _ in range(self.num_res_blocks + 1):
757
+ res_block.append(
758
+ ResnetBlock(
759
+ in_channels=block_in,
760
+ out_channels=block_out,
761
+ temb_channels=self.temb_ch,
762
+ dropout=dropout,
763
+ )
764
+ )
765
+ block_in = block_out
766
+ self.res_blocks.append(nn.ModuleList(res_block))
767
+ if i_level != self.num_resolutions - 1:
768
+ self.upsample_blocks.append(Upsample(block_in, True))
769
+ curr_res = curr_res * 2
770
+
771
+ # end
772
+ self.norm_out = Normalize(block_in)
773
+ self.conv_out = torch.nn.Conv2d(
774
+ block_in, out_channels, kernel_size=3, stride=1, padding=1
775
+ )
776
+
777
+ def forward(self, x):
778
+ # upsampling
779
+ h = x
780
+ for k, i_level in enumerate(range(self.num_resolutions)):
781
+ for i_block in range(self.num_res_blocks + 1):
782
+ h = self.res_blocks[i_level][i_block](h, None)
783
+ if i_level != self.num_resolutions - 1:
784
+ h = self.upsample_blocks[k](h)
785
+ h = self.norm_out(h)
786
+ h = nonlinearity(h)
787
+ h = self.conv_out(h)
788
+ return h
789
+
790
+
791
+ class LatentRescaler(nn.Module):
792
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
793
+ super().__init__()
794
+ # residual block, interpolate, residual block
795
+ self.factor = factor
796
+ self.conv_in = nn.Conv2d(
797
+ in_channels, mid_channels, kernel_size=3, stride=1, padding=1
798
+ )
799
+ self.res_block1 = nn.ModuleList(
800
+ [
801
+ ResnetBlock(
802
+ in_channels=mid_channels,
803
+ out_channels=mid_channels,
804
+ temb_channels=0,
805
+ dropout=0.0,
806
+ )
807
+ for _ in range(depth)
808
+ ]
809
+ )
810
+ self.attn = AttnBlock(mid_channels)
811
+ self.res_block2 = nn.ModuleList(
812
+ [
813
+ ResnetBlock(
814
+ in_channels=mid_channels,
815
+ out_channels=mid_channels,
816
+ temb_channels=0,
817
+ dropout=0.0,
818
+ )
819
+ for _ in range(depth)
820
+ ]
821
+ )
822
+
823
+ self.conv_out = nn.Conv2d(
824
+ mid_channels,
825
+ out_channels,
826
+ kernel_size=1,
827
+ )
828
+
829
+ def forward(self, x):
830
+ x = self.conv_in(x)
831
+ for block in self.res_block1:
832
+ x = block(x, None)
833
+ x = torch.nn.functional.interpolate(
834
+ x,
835
+ size=(
836
+ int(round(x.shape[2] * self.factor)),
837
+ int(round(x.shape[3] * self.factor)),
838
+ ),
839
+ )
840
+ x = self.attn(x).contiguous()
841
+ for block in self.res_block2:
842
+ x = block(x, None)
843
+ x = self.conv_out(x)
844
+ return x
845
+
846
+
847
+ class MergedRescaleEncoder(nn.Module):
848
+ def __init__(
849
+ self,
850
+ in_channels,
851
+ ch,
852
+ resolution,
853
+ out_ch,
854
+ num_res_blocks,
855
+ attn_resolutions,
856
+ dropout=0.0,
857
+ resamp_with_conv=True,
858
+ ch_mult=(1, 2, 4, 8),
859
+ rescale_factor=1.0,
860
+ rescale_module_depth=1,
861
+ ):
862
+ super().__init__()
863
+ intermediate_chn = ch * ch_mult[-1]
864
+ self.encoder = Encoder(
865
+ in_channels=in_channels,
866
+ num_res_blocks=num_res_blocks,
867
+ ch=ch,
868
+ ch_mult=ch_mult,
869
+ z_channels=intermediate_chn,
870
+ double_z=False,
871
+ resolution=resolution,
872
+ attn_resolutions=attn_resolutions,
873
+ dropout=dropout,
874
+ resamp_with_conv=resamp_with_conv,
875
+ out_ch=None,
876
+ )
877
+ self.rescaler = LatentRescaler(
878
+ factor=rescale_factor,
879
+ in_channels=intermediate_chn,
880
+ mid_channels=intermediate_chn,
881
+ out_channels=out_ch,
882
+ depth=rescale_module_depth,
883
+ )
884
+
885
+ def forward(self, x):
886
+ x = self.encoder(x)
887
+ x = self.rescaler(x)
888
+ return x
889
+
890
+
891
+ class MergedRescaleDecoder(nn.Module):
892
+ def __init__(
893
+ self,
894
+ z_channels,
895
+ out_ch,
896
+ resolution,
897
+ num_res_blocks,
898
+ attn_resolutions,
899
+ ch,
900
+ ch_mult=(1, 2, 4, 8),
901
+ dropout=0.0,
902
+ resamp_with_conv=True,
903
+ rescale_factor=1.0,
904
+ rescale_module_depth=1,
905
+ ):
906
+ super().__init__()
907
+ tmp_chn = z_channels * ch_mult[-1]
908
+ self.decoder = Decoder(
909
+ out_ch=out_ch,
910
+ z_channels=tmp_chn,
911
+ attn_resolutions=attn_resolutions,
912
+ dropout=dropout,
913
+ resamp_with_conv=resamp_with_conv,
914
+ in_channels=None,
915
+ num_res_blocks=num_res_blocks,
916
+ ch_mult=ch_mult,
917
+ resolution=resolution,
918
+ ch=ch,
919
+ )
920
+ self.rescaler = LatentRescaler(
921
+ factor=rescale_factor,
922
+ in_channels=z_channels,
923
+ mid_channels=tmp_chn,
924
+ out_channels=tmp_chn,
925
+ depth=rescale_module_depth,
926
+ )
927
+
928
+ def forward(self, x):
929
+ x = self.rescaler(x)
930
+ x = self.decoder(x)
931
+ return x
932
+
933
+
934
+ class Upsampler(nn.Module):
935
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
936
+ super().__init__()
937
+ assert out_size >= in_size
938
+ num_blocks = int(np.log2(out_size // in_size)) + 1
939
+ factor_up = 1.0 + (out_size % in_size)
940
+ print(
941
+ f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
942
+ )
943
+ self.rescaler = LatentRescaler(
944
+ factor=factor_up,
945
+ in_channels=in_channels,
946
+ mid_channels=2 * in_channels,
947
+ out_channels=in_channels,
948
+ )
949
+ self.decoder = Decoder(
950
+ out_ch=out_channels,
951
+ resolution=out_size,
952
+ z_channels=in_channels,
953
+ num_res_blocks=2,
954
+ attn_resolutions=[],
955
+ in_channels=None,
956
+ ch=in_channels,
957
+ ch_mult=[ch_mult for _ in range(num_blocks)],
958
+ )
959
+
960
+ def forward(self, x):
961
+ x = self.rescaler(x)
962
+ x = self.decoder(x)
963
+ return x
964
+
965
+
966
+ class Resize(nn.Module):
967
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
968
+ super().__init__()
969
+ self.with_conv = learned
970
+ self.mode = mode
971
+ if self.with_conv:
972
+ print(
973
+ f"Note: {self.__class__.__name} uses learned downsampling "
974
+ f"and will ignore the fixed {mode} mode"
975
+ )
976
+ raise NotImplementedError()
977
+ assert in_channels is not None
978
+ # no asymmetric padding in torch conv, must do it ourselves
979
+ self.conv = torch.nn.Conv2d(
980
+ in_channels, in_channels, kernel_size=4, stride=2, padding=1
981
+ )
982
+
983
+ def forward(self, x, scale_factor=1.0):
984
+ if scale_factor == 1.0:
985
+ return x
986
+ else:
987
+ x = torch.nn.functional.interpolate(
988
+ x, mode=self.mode, align_corners=False, scale_factor=scale_factor
989
+ )
990
+ return x
991
+
992
+
993
+ class FirstStagePostProcessor(nn.Module):
994
+ def __init__(
995
+ self,
996
+ ch_mult: list,
997
+ in_channels,
998
+ pretrained_model: nn.Module = None,
999
+ reshape=False,
1000
+ n_channels=None,
1001
+ dropout=0.0,
1002
+ pretrained_config=None,
1003
+ ):
1004
+ super().__init__()
1005
+ if pretrained_config is None:
1006
+ assert (
1007
+ pretrained_model is not None
1008
+ ), 'Either "pretrained_model" or "pretrained_config" must not be None'
1009
+ self.pretrained_model = pretrained_model
1010
+ else:
1011
+ assert (
1012
+ pretrained_config is not None
1013
+ ), 'Either "pretrained_model" or "pretrained_config" must not be None'
1014
+ self.instantiate_pretrained(pretrained_config)
1015
+
1016
+ self.do_reshape = reshape
1017
+
1018
+ if n_channels is None:
1019
+ n_channels = self.pretrained_model.encoder.ch
1020
+
1021
+ self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)
1022
+ self.proj = nn.Conv2d(
1023
+ in_channels, n_channels, kernel_size=3, stride=1, padding=1
1024
+ )
1025
+
1026
+ blocks = []
1027
+ downs = []
1028
+ ch_in = n_channels
1029
+ for m in ch_mult:
1030
+ blocks.append(
1031
+ ResnetBlock(
1032
+ in_channels=ch_in, out_channels=m * n_channels, dropout=dropout
1033
+ )
1034
+ )
1035
+ ch_in = m * n_channels
1036
+ downs.append(Downsample(ch_in, with_conv=False))
1037
+
1038
+ self.model = nn.ModuleList(blocks)
1039
+ self.downsampler = nn.ModuleList(downs)
1040
+
1041
+ def instantiate_pretrained(self, config):
1042
+ model = instantiate_from_config(config)
1043
+ self.pretrained_model = model.eval()
1044
+ # self.pretrained_model.train = False
1045
+ for param in self.pretrained_model.parameters():
1046
+ param.requires_grad = False
1047
+
1048
+ @torch.no_grad()
1049
+ def encode_with_pretrained(self, x):
1050
+ c = self.pretrained_model.encode(x)
1051
+ if isinstance(c, DiagonalGaussianDistribution):
1052
+ c = c.mode()
1053
+ return c
1054
+
1055
+ def forward(self, x):
1056
+ z_fs = self.encode_with_pretrained(x)
1057
+ z = self.proj_norm(z_fs)
1058
+ z = self.proj(z)
1059
+ z = nonlinearity(z)
1060
+
1061
+ for submodel, downmodel in zip(self.model, self.downsampler):
1062
+ z = submodel(z, temb=None)
1063
+ z = downmodel(z)
1064
+
1065
+ if self.do_reshape:
1066
+ z = rearrange(z, "b c h w -> b (h w) c")
1067
+ return z
consistencytta.py ADDED
@@ -0,0 +1,200 @@
+ import torch
+ from torch import nn, Tensor
+ from transformers import AutoTokenizer, T5EncoderModel
+
+ from diffusers.utils.torch_utils import randn_tensor
+ from diffusers import UNet2DConditionGuidedModel, HeunDiscreteScheduler
+ from audioldm.stft import TacotronSTFT
+ from audioldm.variational_autoencoder import AutoencoderKL
+ from audioldm.utils import default_audioldm_config
+
+
+ class ConsistencyTTA(nn.Module):
+
+     def __init__(self):
+         super().__init__()
+
+         # Initialize the consistency U-Net
+         unet_model_config_path = 'tango_diffusion_light.json'
+         unet_config = UNet2DConditionGuidedModel.load_config(unet_model_config_path)
+         self.unet = UNet2DConditionGuidedModel.from_config(unet_config, subfolder="unet")
+
+         unet_weight_path = "consistencytta_clapft_ckpt/unet_state_dict.pt"
+         unet_weight_sd = torch.load(unet_weight_path, map_location='cpu')
+         self.unet.load_state_dict(unet_weight_sd)
+
+         # Initialize FLAN-T5 tokenizer and text encoder
+         text_encoder_name = 'google/flan-t5-large'
+         self.tokenizer = AutoTokenizer.from_pretrained(text_encoder_name)
+         self.text_encoder = T5EncoderModel.from_pretrained(text_encoder_name)
+         self.text_encoder.eval(); self.text_encoder.requires_grad_(False)
+
+         # Initialize the VAE
+         raw_vae_path = "consistencytta_clapft_ckpt/vae_state_dict.pt"
+         raw_vae_sd = torch.load(raw_vae_path, map_location="cpu")
+         vae_state_dict, scale_factor = raw_vae_sd["state_dict"], raw_vae_sd["scale_factor"]
+
+         config = default_audioldm_config('audioldm-s-full')
+         vae_config = config["model"]["params"]["first_stage_config"]["params"]
+         vae_config["scale_factor"] = scale_factor
+
+         self.vae = AutoencoderKL(**vae_config)
+         self.vae.load_state_dict(vae_state_dict)
+         self.vae.eval(); self.vae.requires_grad_(False)
+
+         # Initialize the STFT
+         self.fn_STFT = TacotronSTFT(
+             config["preprocessing"]["stft"]["filter_length"],   # default 1024
+             config["preprocessing"]["stft"]["hop_length"],      # default 160
+             config["preprocessing"]["stft"]["win_length"],      # default 1024
+             config["preprocessing"]["mel"]["n_mel_channels"],   # default 64
+             config["preprocessing"]["audio"]["sampling_rate"],  # default 16000
+             config["preprocessing"]["mel"]["mel_fmin"],         # default 0
+             config["preprocessing"]["mel"]["mel_fmax"],         # default 8000
+         )
+         self.fn_STFT.eval(); self.fn_STFT.requires_grad_(False)
+
+         self.scheduler = HeunDiscreteScheduler.from_pretrained(
+             pretrained_model_name_or_path='stabilityai/stable-diffusion-2-1', subfolder="scheduler"
+         )
+
+     def train(self, mode: bool = True):
+         self.unet.train(mode)
+         for model in [self.text_encoder, self.vae, self.fn_STFT]:
+             model.eval()
+         return self
+
+     def eval(self):
+         return self.train(mode=False)
+
+     def check_eval_mode(self):
+         for model, name in zip(
+             [self.text_encoder, self.vae, self.fn_STFT, self.unet],
+             ['text_encoder', 'vae', 'fn_STFT', 'unet']
+         ):
+             assert model.training == False, f"The {name} is not in eval mode."
+             for param in model.parameters():
+                 assert param.requires_grad == False, f"The {name} is not frozen."
+
+     @torch.no_grad()
+     def encode_text(self, prompt, max_length=None, padding=True):
+         device = self.text_encoder.device
+         if max_length is None:
+             max_length = self.tokenizer.model_max_length
+
+         batch = self.tokenizer(
+             prompt, max_length=max_length, padding=padding,
+             truncation=True, return_tensors="pt"
+         )
+         input_ids = batch.input_ids.to(device)
+         attention_mask = batch.attention_mask.to(device)
+
+         prompt_embeds = self.text_encoder(
+             input_ids=input_ids, attention_mask=attention_mask
+         )[0]
+         bool_prompt_mask = (attention_mask == 1).to(device)  # Convert to boolean
+         return prompt_embeds, bool_prompt_mask
+
+     @torch.no_grad()
+     def encode_text_classifier_free(self, prompt: str, num_samples_per_prompt: int):
+         # get conditional embeddings
+         cond_prompt_embeds, cond_prompt_mask = self.encode_text(prompt)
+         cond_prompt_embeds = cond_prompt_embeds.repeat_interleave(
+             num_samples_per_prompt, 0
+         )
+         cond_prompt_mask = cond_prompt_mask.repeat_interleave(
+             num_samples_per_prompt, 0
+         )
+
+         # get unconditional embeddings for classifier free guidance
+         uncond_tokens = [""] * len(prompt)
+         negative_prompt_embeds, uncond_prompt_mask = self.encode_text(
+             uncond_tokens, max_length=cond_prompt_embeds.shape[1], padding="max_length"
+         )
+         negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(
+             num_samples_per_prompt, 0
+         )
+         uncond_prompt_mask = uncond_prompt_mask.repeat_interleave(
+             num_samples_per_prompt, 0
+         )
+
+         # For classifier-free guidance, we need to do two forward passes:
+         # we concatenate the unconditional and text embeddings into a single batch.
+         prompt_embeds = torch.cat([negative_prompt_embeds, cond_prompt_embeds])
+         prompt_mask = torch.cat([uncond_prompt_mask, cond_prompt_mask])
+
+         return prompt_embeds, prompt_mask, cond_prompt_embeds, cond_prompt_mask
+
+     def forward(
+         self, prompt: str, cfg_scale_input: float = 3., cfg_scale_post: float = 1.,
+         num_steps: int = 1, num_samples: int = 1, sr: int = 16000
+     ):
+         self.check_eval_mode()
+         device = self.text_encoder.device
+         use_cf_guidance = cfg_scale_post > 1.
+
+         # Get prompt embeddings
+         prompt_embeds_cf, prompt_mask_cf, prompt_embeds, prompt_mask = \
+             self.encode_text_classifier_free(prompt, num_samples)
+         encoder_states, encoder_att_mask = \
+             (prompt_embeds_cf, prompt_mask_cf) if use_cf_guidance \
+             else (prompt_embeds, prompt_mask)
+
+         # Prepare noise
+         num_channels_latents = self.unet.config.in_channels
+         latent_shape = (len(prompt) * num_samples, num_channels_latents, 256, 16)
+         noise = randn_tensor(
+             latent_shape, generator=None, device=device, dtype=prompt_embeds.dtype
+         )
+
+         # Query the inference scheduler to obtain the time steps.
+         # The time steps spread between 0 and training time steps
+         self.scheduler.set_timesteps(18, device=device)  # Set this to training steps first
+         z_N = noise * self.scheduler.init_noise_sigma
+
+         def calc_zhat_0(z_n: Tensor, t: int):
+             """ Query the consistency model to get zhat_0, which is the denoised embedding.
+             Args:
+                 z_n (Tensor): The noisy embedding.
+                 t (int): The time step.
+             Returns:
+                 Tensor: The denoised embedding.
+             """
+             # expand the latents if we are doing classifier free guidance
+             z_n_input = torch.cat([z_n] * 2) if use_cf_guidance else z_n
+             # Scale model input as required for some schedules.
+             z_n_input = self.scheduler.scale_model_input(z_n_input, t)
+
+             # Get zhat_0 from the model
+             zhat_0 = self.unet(
+                 z_n_input, t, guidance=cfg_scale_input,
+                 encoder_hidden_states=encoder_states, encoder_attention_mask=encoder_att_mask
+             ).sample
+
+             # Perform external classifier-free guidance
+             if use_cf_guidance:
+                 zhat_0_uncond, zhat_0_cond = zhat_0.chunk(2)
+                 zhat_0 = (1 - cfg_scale_post) * zhat_0_uncond + cfg_scale_post * zhat_0_cond
+
+             return zhat_0
+
+         # Query the consistency model
+         zhat_0 = calc_zhat_0(z_N, self.scheduler.timesteps[0])
+
+         # Iteratively query the consistency model if requested
+         self.scheduler.set_timesteps(num_steps, device=device)
+
+         for t in self.scheduler.timesteps[1::2]:  # 2 is the order of the scheduler
+             zhat_n = self.scheduler.add_noise(zhat_0, torch.randn_like(zhat_0), t)
+             # Calculate new zhat_0
+             zhat_0 = calc_zhat_0(zhat_n, t)
+
+         mel = self.vae.decode_first_stage(zhat_0.float())
+         return self.vae.decode_to_waveform(mel)[:, :int(sr * 9.5)]  # Truncate to 9.5 seconds
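A hedged end-to-end sketch of how this class might be driven, for example from `run_gradio.py`. The prompt, output file name, and use of `soundfile` are illustrative assumptions; the checkpoints under `consistencytta_clapft_ckpt/` must already be present.

import torch
import soundfile as sf

model = ConsistencyTTA().eval()
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

with torch.no_grad():
    # cfg_scale_input is passed to the U-Net as its guidance conditioning;
    # setting cfg_scale_post > 1 additionally enables external CFG (two forward passes).
    wav = model(
        ["a dog barking in the distance"],
        cfg_scale_input=3., cfg_scale_post=1., num_steps=1, num_samples=1,
    )

wav = wav[0]
if torch.is_tensor(wav):            # the vocoder utility may return a tensor or an ndarray
    wav = wav.detach().cpu().numpy()
sf.write("sample.wav", wav, samplerate=16000)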
consistencytta_clapft_ckpt/.DS_Store ADDED
Binary file (6.15 kB).
 
diffusers/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .scheduling_heun_discrete import HeunDiscreteScheduler
+ from .models.unet_2d_condition_guided import UNet2DConditionGuidedModel
diffusers/models/__init__.py ADDED
@@ -0,0 +1,23 @@
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from ..utils.import_utils import is_torch_available
+
+
+ if is_torch_available():
+     from .modeling_utils import ModelMixin
+     from .prior_transformer import PriorTransformer
+     from .unet_2d import UNet2DModel
+     from .unet_2d_condition import UNet2DConditionModel
+     from .unet_2d_condition_guided import UNet2DConditionGuidedModel
diffusers/models/activations.py ADDED
@@ -0,0 +1,12 @@
+ from torch import nn
+
+
+ def get_activation(act_fn):
+     if act_fn in ["swish", "silu"]:
+         return nn.SiLU()
+     elif act_fn == "mish":
+         return nn.Mish()
+     elif act_fn == "gelu":
+         return nn.GELU()
+     else:
+         raise ValueError(f"Unsupported activation function: {act_fn}")
diffusers/models/attention.py ADDED
@@ -0,0 +1,523 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from typing import Any, Callable, Dict, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from ..utils.import_utils import is_xformers_available
22
+ from .attention_processor import Attention
23
+ from .embeddings import CombinedTimestepLabelEmbeddings
24
+
25
+
26
+ if is_xformers_available():
27
+ import xformers
28
+ import xformers.ops
29
+ else:
30
+ xformers = None
31
+
32
+
33
+ class AttentionBlock(nn.Module):
34
+ """
35
+ An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
36
+ to the N-d case.
37
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
38
+ Uses three q, k, v linear layers to compute attention.
39
+
40
+ Parameters:
41
+ channels (`int`): The number of channels in the input and output.
42
+ num_head_channels (`int`, *optional*):
43
+ The number of channels in each head. If None, then `num_heads` = 1.
44
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
45
+ rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
46
+ eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
47
+ """
48
+
49
+ # IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore
50
+
51
+ def __init__(
52
+ self,
53
+ channels: int,
54
+ num_head_channels: Optional[int] = None,
55
+ norm_num_groups: int = 32,
56
+ rescale_output_factor: float = 1.0,
57
+ eps: float = 1e-5,
58
+ ):
59
+ super().__init__()
60
+ self.channels = channels
61
+
62
+ self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
63
+ self.num_head_size = num_head_channels
64
+ self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
65
+
66
+ # define q,k,v as linear layers
67
+ self.query = nn.Linear(channels, channels)
68
+ self.key = nn.Linear(channels, channels)
69
+ self.value = nn.Linear(channels, channels)
70
+
71
+ self.rescale_output_factor = rescale_output_factor
72
+ self.proj_attn = nn.Linear(channels, channels, bias=True)
73
+
74
+ self._use_memory_efficient_attention_xformers = False
75
+ self._attention_op = None
76
+
77
+ def reshape_heads_to_batch_dim(self, tensor):
78
+ batch_size, seq_len, dim = tensor.shape
79
+ head_size = self.num_heads
80
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
81
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
82
+ return tensor
83
+
84
+ def reshape_batch_dim_to_heads(self, tensor):
85
+ batch_size, seq_len, dim = tensor.shape
86
+ head_size = self.num_heads
87
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
88
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
89
+ return tensor
90
+
91
+ def set_use_memory_efficient_attention_xformers(
92
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
93
+ ):
94
+ if use_memory_efficient_attention_xformers:
95
+ if not is_xformers_available():
96
+ raise ModuleNotFoundError(
97
+ (
98
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
99
+ " xformers"
100
+ ),
101
+ name="xformers",
102
+ )
103
+ elif not torch.cuda.is_available():
104
+ raise ValueError(
105
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
106
+ " only available for GPU "
107
+ )
108
+ else:
109
+ try:
110
+ # Make sure we can run the memory efficient attention
111
+ _ = xformers.ops.memory_efficient_attention(
112
+ torch.randn((1, 2, 40), device="cuda"),
113
+ torch.randn((1, 2, 40), device="cuda"),
114
+ torch.randn((1, 2, 40), device="cuda"),
115
+ )
116
+ except Exception as e:
117
+ raise e
118
+ self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
119
+ self._attention_op = attention_op
120
+
121
+ def forward(self, hidden_states):
122
+ residual = hidden_states
123
+ batch, channel, height, width = hidden_states.shape
124
+
125
+ # norm
126
+ hidden_states = self.group_norm(hidden_states)
127
+
128
+ hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
129
+
130
+ # proj to q, k, v
131
+ query_proj = self.query(hidden_states)
132
+ key_proj = self.key(hidden_states)
133
+ value_proj = self.value(hidden_states)
134
+
135
+ scale = 1 / math.sqrt(self.channels / self.num_heads)
136
+
137
+ query_proj = self.reshape_heads_to_batch_dim(query_proj)
138
+ key_proj = self.reshape_heads_to_batch_dim(key_proj)
139
+ value_proj = self.reshape_heads_to_batch_dim(value_proj)
140
+
141
+ if self._use_memory_efficient_attention_xformers:
142
+ # Memory efficient attention
143
+ hidden_states = xformers.ops.memory_efficient_attention(
144
+ query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
145
+ )
146
+ hidden_states = hidden_states.to(query_proj.dtype)
147
+ else:
148
+ attention_scores = torch.baddbmm(
149
+ torch.empty(
150
+ query_proj.shape[0],
151
+ query_proj.shape[1],
152
+ key_proj.shape[1],
153
+ dtype=query_proj.dtype,
154
+ device=query_proj.device,
155
+ ),
156
+ query_proj,
157
+ key_proj.transpose(-1, -2),
158
+ beta=0,
159
+ alpha=scale,
160
+ )
161
+ attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
162
+ hidden_states = torch.bmm(attention_probs, value_proj)
163
+
164
+ # reshape hidden_states
165
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
166
+
167
+ # compute next hidden_states
168
+ hidden_states = self.proj_attn(hidden_states)
169
+
170
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
171
+
172
+ # res connect and rescale
173
+ hidden_states = (hidden_states + residual) / self.rescale_output_factor
174
+ return hidden_states
175
+
176
+
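An illustrative shape check for the (soon-to-be-deprecated) `AttentionBlock`, using the default single-head, non-xformers path:

import torch

block = AttentionBlock(channels=64)   # num_head_channels=None -> a single head
x = torch.randn(2, 64, 16, 16)        # (batch, channels, height, width)
assert block(x).shape == x.shape      # residual add and rescale keep the shape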
177
+ class BasicTransformerBlock(nn.Module):
178
+ r"""
179
+ A basic Transformer block.
180
+
181
+ Parameters:
182
+ dim (`int`): The number of channels in the input and output.
183
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
184
+ attention_head_dim (`int`): The number of channels in each head.
185
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
186
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
187
+ only_cross_attention (`bool`, *optional*):
188
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
189
+ double_self_attention (`bool`, *optional*):
190
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
191
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
192
+ num_embeds_ada_norm (:
193
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
194
+ attention_bias (:
195
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
196
+ """
197
+
198
+ def __init__(
199
+ self,
200
+ dim: int,
201
+ num_attention_heads: int,
202
+ attention_head_dim: int,
203
+ dropout=0.0,
204
+ cross_attention_dim: Optional[int] = None,
205
+ activation_fn: str = "geglu",
206
+ num_embeds_ada_norm: Optional[int] = None,
207
+ attention_bias: bool = False,
208
+ only_cross_attention: bool = False,
209
+ double_self_attention: bool = False,
210
+ upcast_attention: bool = False,
211
+ norm_elementwise_affine: bool = True,
212
+ norm_type: str = "layer_norm",
213
+ final_dropout: bool = False,
214
+ ):
215
+ super().__init__()
216
+ self.only_cross_attention = only_cross_attention
217
+
218
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
219
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
220
+
221
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
222
+ raise ValueError(
223
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
224
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
225
+ )
226
+
227
+ # 1. Self-Attn
228
+ self.attn1 = Attention(
229
+ query_dim=dim,
230
+ heads=num_attention_heads,
231
+ dim_head=attention_head_dim,
232
+ dropout=dropout,
233
+ bias=attention_bias,
234
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
235
+ upcast_attention=upcast_attention,
236
+ )
237
+
238
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
239
+
240
+ # 2. Cross-Attn
241
+ if cross_attention_dim is not None or double_self_attention:
242
+ self.attn2 = Attention(
243
+ query_dim=dim,
244
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
245
+ heads=num_attention_heads,
246
+ dim_head=attention_head_dim,
247
+ dropout=dropout,
248
+ bias=attention_bias,
249
+ upcast_attention=upcast_attention,
250
+ ) # is self-attn if encoder_hidden_states is none
251
+ else:
252
+ self.attn2 = None
253
+
254
+ if self.use_ada_layer_norm:
255
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
256
+ elif self.use_ada_layer_norm_zero:
257
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
258
+ else:
259
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
260
+
261
+ if cross_attention_dim is not None or double_self_attention:
262
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
263
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
264
+ # the second cross attention block.
265
+ self.norm2 = (
266
+ AdaLayerNorm(dim, num_embeds_ada_norm)
267
+ if self.use_ada_layer_norm
268
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
269
+ )
270
+ else:
271
+ self.norm2 = None
272
+
273
+ # 3. Feed-forward
274
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
275
+
276
+ def forward(
277
+ self,
278
+ hidden_states: torch.FloatTensor,
279
+ attention_mask: Optional[torch.FloatTensor] = None,
280
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
281
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
282
+ timestep: Optional[torch.LongTensor] = None,
283
+ cross_attention_kwargs: Dict[str, Any] = None,
284
+ class_labels: Optional[torch.LongTensor] = None,
285
+ ):
286
+ if self.use_ada_layer_norm:
287
+ norm_hidden_states = self.norm1(hidden_states, timestep)
288
+ elif self.use_ada_layer_norm_zero:
289
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
290
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
291
+ )
292
+ else:
293
+ norm_hidden_states = self.norm1(hidden_states)
294
+
295
+ # 1. Self-Attention
296
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
297
+ attn_output = self.attn1(
298
+ norm_hidden_states,
299
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
300
+ attention_mask=attention_mask,
301
+ **cross_attention_kwargs,
302
+ )
303
+ if self.use_ada_layer_norm_zero:
304
+ attn_output = gate_msa.unsqueeze(1) * attn_output
305
+ hidden_states = attn_output + hidden_states
306
+
307
+ if self.attn2 is not None:
308
+ norm_hidden_states = (
309
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
310
+ )
311
+
312
+ # 2. Cross-Attention
313
+ attn_output = self.attn2(
314
+ norm_hidden_states,
315
+ encoder_hidden_states=encoder_hidden_states,
316
+ attention_mask=encoder_attention_mask,
317
+ **cross_attention_kwargs,
318
+ )
319
+ hidden_states = attn_output + hidden_states
320
+
321
+ # 3. Feed-forward
322
+ norm_hidden_states = self.norm3(hidden_states)
323
+
324
+ if self.use_ada_layer_norm_zero:
325
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
326
+
327
+ ff_output = self.ff(norm_hidden_states)
328
+
329
+ if self.use_ada_layer_norm_zero:
330
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
331
+
332
+ hidden_states = ff_output + hidden_states
333
+
334
+ return hidden_states
335
+
336
+
337
+ class FeedForward(nn.Module):
338
+ r"""
339
+ A feed-forward layer.
340
+
341
+ Parameters:
342
+ dim (`int`): The number of channels in the input.
343
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
344
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
345
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
346
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
347
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
348
+ """
349
+
350
+ def __init__(
351
+ self,
352
+ dim: int,
353
+ dim_out: Optional[int] = None,
354
+ mult: int = 4,
355
+ dropout: float = 0.0,
356
+ activation_fn: str = "geglu",
357
+ final_dropout: bool = False,
358
+ ):
359
+ super().__init__()
360
+ inner_dim = int(dim * mult)
361
+ dim_out = dim_out if dim_out is not None else dim
362
+
363
+ if activation_fn == "gelu":
364
+ act_fn = GELU(dim, inner_dim)
365
+ elif activation_fn == "gelu-approximate":
366
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
367
+ elif activation_fn == "geglu":
368
+ act_fn = GEGLU(dim, inner_dim)
369
+ elif activation_fn == "geglu-approximate":
370
+ act_fn = ApproximateGELU(dim, inner_dim)
371
+
372
+ self.net = nn.ModuleList([])
373
+ # project in
374
+ self.net.append(act_fn)
375
+ # project dropout
376
+ self.net.append(nn.Dropout(dropout))
377
+ # project out
378
+ self.net.append(nn.Linear(inner_dim, dim_out))
379
+ # FF layers as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
380
+ if final_dropout:
381
+ self.net.append(nn.Dropout(dropout))
382
+
383
+ def forward(self, hidden_states):
384
+ for module in self.net:
385
+ hidden_states = module(hidden_states)
386
+ return hidden_states
387
+
388
+
389
+ class GELU(nn.Module):
390
+ r"""
391
+ GELU activation function with tanh approximation support with `approximate="tanh"`.
392
+ """
393
+
394
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
395
+ super().__init__()
396
+ self.proj = nn.Linear(dim_in, dim_out)
397
+ self.approximate = approximate
398
+
399
+ def gelu(self, gate):
400
+ if gate.device.type != "mps":
401
+ return F.gelu(gate, approximate=self.approximate)
402
+ # mps: gelu is not implemented for float16
403
+ return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
404
+
405
+ def forward(self, hidden_states):
406
+ hidden_states = self.proj(hidden_states)
407
+ hidden_states = self.gelu(hidden_states)
408
+ return hidden_states
409
+
410
+
411
+ class GEGLU(nn.Module):
412
+ r"""
413
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
414
+
415
+ Parameters:
416
+ dim_in (`int`): The number of channels in the input.
417
+ dim_out (`int`): The number of channels in the output.
418
+ """
419
+
420
+ def __init__(self, dim_in: int, dim_out: int):
421
+ super().__init__()
422
+ self.proj = nn.Linear(dim_in, dim_out * 2)
423
+
424
+ def gelu(self, gate):
425
+ if gate.device.type != "mps":
426
+ return F.gelu(gate)
427
+ # mps: gelu is not implemented for float16
428
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
429
+
430
+ def forward(self, hidden_states):
431
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
432
+ return hidden_states * self.gelu(gate)
433
+
434
+
435
+ class ApproximateGELU(nn.Module):
436
+ """
437
+ The approximate form of Gaussian Error Linear Unit (GELU)
438
+
439
+ For more details, see section 2: https://arxiv.org/abs/1606.08415
440
+ """
441
+
442
+ def __init__(self, dim_in: int, dim_out: int):
443
+ super().__init__()
444
+ self.proj = nn.Linear(dim_in, dim_out)
445
+
446
+ def forward(self, x):
447
+ x = self.proj(x)
448
+ return x * torch.sigmoid(1.702 * x)
449
+
450
+
451
+ class AdaLayerNorm(nn.Module):
452
+ """
453
+ Norm layer modified to incorporate timestep embeddings.
454
+ """
455
+
456
+ def __init__(self, embedding_dim, num_embeddings):
457
+ super().__init__()
458
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
459
+ self.silu = nn.SiLU()
460
+ self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
461
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
462
+
463
+ def forward(self, x, timestep):
464
+ emb = self.linear(self.silu(self.emb(timestep)))
465
+ scale, shift = torch.chunk(emb, 2)
466
+ x = self.norm(x) * (1 + scale) + shift
467
+ return x
468
+
469
+
470
+ class AdaLayerNormZero(nn.Module):
471
+ """
472
+ Adaptive layer norm zero (adaLN-Zero): a norm layer modulated by timestep and class-label embeddings.
473
+ """
474
+
475
+ def __init__(self, embedding_dim, num_embeddings):
476
+ super().__init__()
477
+
478
+ self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
479
+
480
+ self.silu = nn.SiLU()
481
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
482
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
483
+
484
+ def forward(self, x, timestep, class_labels, hidden_dtype=None):
485
+ emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
486
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
487
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
488
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
489
+
490
+
491
+ class AdaGroupNorm(nn.Module):
492
+ """
493
+ GroupNorm layer modified to incorporate timestep embeddings.
494
+ """
495
+
496
+ def __init__(
497
+ self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
498
+ ):
499
+ super().__init__()
500
+ self.num_groups = num_groups
501
+ self.eps = eps
502
+ self.act = None
503
+ if act_fn == "swish":
504
+ self.act = lambda x: F.silu(x)
505
+ elif act_fn == "mish":
506
+ self.act = nn.Mish()
507
+ elif act_fn == "silu":
508
+ self.act = nn.SiLU()
509
+ elif act_fn == "gelu":
510
+ self.act = nn.GELU()
511
+
512
+ self.linear = nn.Linear(embedding_dim, out_dim * 2)
513
+
514
+ def forward(self, x, emb):
515
+ if self.act:
516
+ emb = self.act(emb)
517
+ emb = self.linear(emb)
518
+ emb = emb[:, :, None, None]
519
+ scale, shift = emb.chunk(2, dim=1)
520
+
521
+ x = F.group_norm(x, self.num_groups, eps=self.eps)
522
+ x = x * (1 + scale) + shift
523
+ return x
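
For orientation before the next file, here is a minimal standalone sketch (not part of the committed files) of the two conditioning mechanisms defined above in `diffusers/models/attention.py`: the GEGLU gating used inside `FeedForward`, and the timestep-dependent scale/shift applied by `AdaLayerNorm`. The sizes, the example timestep, and the variable names are illustrative assumptions.

import torch
import torch.nn as nn
import torch.nn.functional as F

dim, inner_dim, num_train_timesteps = 8, 32, 1000    # illustrative sizes
x = torch.randn(2, 4, dim)                            # (batch, tokens, channels)

# GEGLU: a single projection to 2 * inner_dim, split into a value and a gate
proj = nn.Linear(dim, inner_dim * 2)
value, gate = proj(x).chunk(2, dim=-1)
hidden = value * F.gelu(gate)                         # (2, 4, inner_dim)

# AdaLayerNorm: LayerNorm followed by a timestep-conditioned affine transform
emb = nn.Embedding(num_train_timesteps, dim)
to_scale_shift = nn.Linear(dim, dim * 2)
norm = nn.LayerNorm(dim, elementwise_affine=False)
scale, shift = to_scale_shift(F.silu(emb(torch.tensor(37)))).chunk(2)
out = norm(x) * (1 + scale) + shift                   # broadcast over batch and tokens
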
diffusers/models/attention_processor.py ADDED
@@ -0,0 +1,1646 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Callable, Optional, Union
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from ..utils.deprecation_utils import deprecate
21
+ from ..utils.torch_utils import maybe_allow_in_graph
22
+ from ..utils.import_utils import is_xformers_available
23
+ from ..utils.logging import get_logger
24
+
25
+ logger = get_logger(__name__) # pylint: disable=invalid-name
26
+
27
+
28
+ if is_xformers_available():
29
+ import xformers
30
+ import xformers.ops
31
+ else:
32
+ xformers = None
33
+
34
+
35
+ @maybe_allow_in_graph
36
+ class Attention(nn.Module):
37
+ r"""
38
+ A cross attention layer.
39
+
40
+ Parameters:
41
+ query_dim (`int`): The number of channels in the query.
42
+ cross_attention_dim (`int`, *optional*):
43
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
44
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
45
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
46
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
47
+ bias (`bool`, *optional*, defaults to False):
48
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
49
+ """
50
+
51
+ def __init__(
52
+ self,
53
+ query_dim: int,
54
+ cross_attention_dim: Optional[int] = None,
55
+ heads: int = 8,
56
+ dim_head: int = 64,
57
+ dropout: float = 0.0,
58
+ bias=False,
59
+ upcast_attention: bool = False,
60
+ upcast_softmax: bool = False,
61
+ cross_attention_norm: Optional[str] = None,
62
+ cross_attention_norm_num_groups: int = 32,
63
+ added_kv_proj_dim: Optional[int] = None,
64
+ norm_num_groups: Optional[int] = None,
65
+ spatial_norm_dim: Optional[int] = None,
66
+ out_bias: bool = True,
67
+ scale_qk: bool = True,
68
+ only_cross_attention: bool = False,
69
+ eps: float = 1e-5,
70
+ rescale_output_factor: float = 1.0,
71
+ residual_connection: bool = False,
72
+ _from_deprecated_attn_block=False,
73
+ processor: Optional["AttnProcessor"] = None,
74
+ ):
75
+ super().__init__()
76
+ inner_dim = dim_head * heads
77
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
78
+ self.upcast_attention = upcast_attention
79
+ self.upcast_softmax = upcast_softmax
80
+ self.rescale_output_factor = rescale_output_factor
81
+ self.residual_connection = residual_connection
82
+ self.dropout = dropout
83
+
84
+ # we make use of this private variable to know whether this class is loaded
85
+ # with a deprecated state dict so that we can convert it on the fly
86
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
87
+
88
+ self.scale_qk = scale_qk
89
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
90
+
91
+ self.heads = heads
92
+ # for slice_size > 0 the attention score computation
93
+ # is split across the batch axis to save memory
94
+ # You can set slice_size with `set_attention_slice`
95
+ self.sliceable_head_dim = heads
96
+
97
+ self.added_kv_proj_dim = added_kv_proj_dim
98
+ self.only_cross_attention = only_cross_attention
99
+
100
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
101
+ raise ValueError(
102
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
103
+ )
104
+
105
+ if norm_num_groups is not None:
106
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
107
+ else:
108
+ self.group_norm = None
109
+
110
+ if spatial_norm_dim is not None:
111
+ self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
112
+ else:
113
+ self.spatial_norm = None
114
+
115
+ if cross_attention_norm is None:
116
+ self.norm_cross = None
117
+ elif cross_attention_norm == "layer_norm":
118
+ self.norm_cross = nn.LayerNorm(cross_attention_dim)
119
+ elif cross_attention_norm == "group_norm":
120
+ if self.added_kv_proj_dim is not None:
121
+ # The given `encoder_hidden_states` are initially of shape
122
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
123
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
124
+ # before the projection, so we need to use `added_kv_proj_dim` as
125
+ # the number of channels for the group norm.
126
+ norm_cross_num_channels = added_kv_proj_dim
127
+ else:
128
+ norm_cross_num_channels = cross_attention_dim
129
+
130
+ self.norm_cross = nn.GroupNorm(
131
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
132
+ )
133
+ else:
134
+ raise ValueError(
135
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
136
+ )
137
+
138
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
139
+
140
+ if not self.only_cross_attention:
141
+ # only relevant for the `AddedKVProcessor` classes
142
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
143
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
144
+ else:
145
+ self.to_k = None
146
+ self.to_v = None
147
+
148
+ if self.added_kv_proj_dim is not None:
149
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
150
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)
151
+
152
+ self.to_out = nn.ModuleList([])
153
+ self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
154
+ self.to_out.append(nn.Dropout(dropout))
155
+
156
+ # set attention processor
157
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
158
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
159
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
160
+ if processor is None:
161
+ processor = (
162
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
163
+ )
164
+ self.set_processor(processor)
165
+
166
+ def set_use_memory_efficient_attention_xformers(
167
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
168
+ ):
169
+ is_lora = hasattr(self, "processor") and isinstance(
170
+ self.processor,
171
+ (LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor),
172
+ )
173
+ is_custom_diffusion = hasattr(self, "processor") and isinstance(
174
+ self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
175
+ )
176
+ is_added_kv_processor = hasattr(self, "processor") and isinstance(
177
+ self.processor,
178
+ (
179
+ AttnAddedKVProcessor,
180
+ AttnAddedKVProcessor2_0,
181
+ SlicedAttnAddedKVProcessor,
182
+ XFormersAttnAddedKVProcessor,
183
+ LoRAAttnAddedKVProcessor,
184
+ ),
185
+ )
186
+
187
+ if use_memory_efficient_attention_xformers:
188
+ if is_added_kv_processor and (is_lora or is_custom_diffusion):
189
+ raise NotImplementedError(
190
+ f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
191
+ )
192
+ if not is_xformers_available():
193
+ raise ModuleNotFoundError(
194
+ (
195
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
196
+ " xformers"
197
+ ),
198
+ name="xformers",
199
+ )
200
+ elif not torch.cuda.is_available():
201
+ raise ValueError(
202
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
203
+ " only available for GPU "
204
+ )
205
+ else:
206
+ try:
207
+ # Make sure we can run the memory efficient attention
208
+ _ = xformers.ops.memory_efficient_attention(
209
+ torch.randn((1, 2, 40), device="cuda"),
210
+ torch.randn((1, 2, 40), device="cuda"),
211
+ torch.randn((1, 2, 40), device="cuda"),
212
+ )
213
+ except Exception as e:
214
+ raise e
215
+
216
+ if is_lora:
217
+ # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
218
+ # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
219
+ processor = LoRAXFormersAttnProcessor(
220
+ hidden_size=self.processor.hidden_size,
221
+ cross_attention_dim=self.processor.cross_attention_dim,
222
+ rank=self.processor.rank,
223
+ attention_op=attention_op,
224
+ )
225
+ processor.load_state_dict(self.processor.state_dict())
226
+ processor.to(self.processor.to_q_lora.up.weight.device)
227
+ elif is_custom_diffusion:
228
+ processor = CustomDiffusionXFormersAttnProcessor(
229
+ train_kv=self.processor.train_kv,
230
+ train_q_out=self.processor.train_q_out,
231
+ hidden_size=self.processor.hidden_size,
232
+ cross_attention_dim=self.processor.cross_attention_dim,
233
+ attention_op=attention_op,
234
+ )
235
+ processor.load_state_dict(self.processor.state_dict())
236
+ if hasattr(self.processor, "to_k_custom_diffusion"):
237
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
238
+ elif is_added_kv_processor:
239
+ # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
240
+ # which uses this type of cross attention ONLY because the attention mask of format
241
+ # [0, ..., -10.000, ..., 0, ...,] is not supported
242
+ # throw warning
243
+ logger.info(
244
+ "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
245
+ )
246
+ processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
247
+ else:
248
+ processor = XFormersAttnProcessor(attention_op=attention_op)
249
+ else:
250
+ if is_lora:
251
+ attn_processor_class = (
252
+ LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
253
+ )
254
+ processor = attn_processor_class(
255
+ hidden_size=self.processor.hidden_size,
256
+ cross_attention_dim=self.processor.cross_attention_dim,
257
+ rank=self.processor.rank,
258
+ )
259
+ processor.load_state_dict(self.processor.state_dict())
260
+ processor.to(self.processor.to_q_lora.up.weight.device)
261
+ elif is_custom_diffusion:
262
+ processor = CustomDiffusionAttnProcessor(
263
+ train_kv=self.processor.train_kv,
264
+ train_q_out=self.processor.train_q_out,
265
+ hidden_size=self.processor.hidden_size,
266
+ cross_attention_dim=self.processor.cross_attention_dim,
267
+ )
268
+ processor.load_state_dict(self.processor.state_dict())
269
+ if hasattr(self.processor, "to_k_custom_diffusion"):
270
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
271
+ else:
272
+ # set attention processor
273
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
274
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
275
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
276
+ processor = (
277
+ AttnProcessor2_0()
278
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
279
+ else AttnProcessor()
280
+ )
281
+
282
+ self.set_processor(processor)
283
+
284
+ def set_attention_slice(self, slice_size):
285
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
286
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
287
+
288
+ if slice_size is not None and self.added_kv_proj_dim is not None:
289
+ processor = SlicedAttnAddedKVProcessor(slice_size)
290
+ elif slice_size is not None:
291
+ processor = SlicedAttnProcessor(slice_size)
292
+ elif self.added_kv_proj_dim is not None:
293
+ processor = AttnAddedKVProcessor()
294
+ else:
295
+ # set attention processor
296
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
297
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
298
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
299
+ processor = (
300
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
301
+ )
302
+
303
+ self.set_processor(processor)
304
+
305
+ def set_processor(self, processor: "AttnProcessor"):
306
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
307
+ # pop `processor` from `self._modules`
308
+ if (
309
+ hasattr(self, "processor")
310
+ and isinstance(self.processor, torch.nn.Module)
311
+ and not isinstance(processor, torch.nn.Module)
312
+ ):
313
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
314
+ self._modules.pop("processor")
315
+
316
+ self.processor = processor
317
+
318
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
319
+ # The `Attention` class can call different attention processors / attention functions
320
+ # here we simply pass along all tensors to the selected processor class
321
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
322
+ return self.processor(
323
+ self,
324
+ hidden_states,
325
+ encoder_hidden_states=encoder_hidden_states,
326
+ attention_mask=attention_mask,
327
+ **cross_attention_kwargs,
328
+ )
329
+
330
+ def batch_to_head_dim(self, tensor):
331
+ head_size = self.heads
332
+ batch_size, seq_len, dim = tensor.shape
333
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
334
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
335
+ return tensor
336
+
337
+ def head_to_batch_dim(self, tensor, out_dim=3):
338
+ head_size = self.heads
339
+ batch_size, seq_len, dim = tensor.shape
340
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
341
+ tensor = tensor.permute(0, 2, 1, 3)
342
+
343
+ if out_dim == 3:
344
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
345
+
346
+ return tensor
347
+
348
+ def get_attention_scores(self, query, key, attention_mask=None):
349
+ dtype = query.dtype
350
+ if self.upcast_attention:
351
+ query = query.float()
352
+ key = key.float()
353
+
354
+ if attention_mask is None:
355
+ baddbmm_input = torch.empty(
356
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
357
+ )
358
+ beta = 0
359
+ else:
360
+ baddbmm_input = attention_mask
361
+ beta = 1
362
+
363
+ attention_scores = torch.baddbmm(
364
+ baddbmm_input,
365
+ query,
366
+ key.transpose(-1, -2),
367
+ beta=beta,
368
+ alpha=self.scale,
369
+ )
370
+ del baddbmm_input
371
+
372
+ if self.upcast_softmax:
373
+ attention_scores = attention_scores.float()
374
+
375
+ attention_probs = attention_scores.softmax(dim=-1)
376
+ del attention_scores
377
+
378
+ attention_probs = attention_probs.to(dtype)
379
+
380
+ return attention_probs
381
+
382
+ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
383
+ if batch_size is None:
384
+ deprecate(
385
+ "batch_size=None",
386
+ "0.0.15",
387
+ (
388
+ "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
389
+ " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
390
+ " `prepare_attention_mask` when preparing the attention_mask."
391
+ ),
392
+ )
393
+ batch_size = 1
394
+
395
+ head_size = self.heads
396
+ if attention_mask is None:
397
+ return attention_mask
398
+
399
+ current_length: int = attention_mask.shape[-1]
400
+ if current_length != target_length:
401
+ if attention_mask.device.type == "mps":
402
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
403
+ # Instead, we can manually construct the padding tensor.
404
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
405
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
406
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
407
+ else:
408
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
409
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
410
+ # remaining_length: int = target_length - current_length
411
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
412
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
413
+
414
+ if out_dim == 3:
415
+ if attention_mask.shape[0] < batch_size * head_size:
416
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
417
+ elif out_dim == 4:
418
+ attention_mask = attention_mask.unsqueeze(1)
419
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
420
+
421
+ return attention_mask
422
+
423
+ def norm_encoder_hidden_states(self, encoder_hidden_states):
424
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
425
+
426
+ if isinstance(self.norm_cross, nn.LayerNorm):
427
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
428
+ elif isinstance(self.norm_cross, nn.GroupNorm):
429
+ # Group norm norms along the channels dimension and expects
430
+ # input to be in the shape of (N, C, *). In this case, we want
431
+ # to norm along the hidden dimension, so we need to move
432
+ # (batch_size, sequence_length, hidden_size) ->
433
+ # (batch_size, hidden_size, sequence_length)
434
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
435
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
436
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
437
+ else:
438
+ assert False
439
+
440
+ return encoder_hidden_states
441
+
442
+
443
+ class AttnProcessor:
444
+ r"""
445
+ Default processor for performing attention-related computations.
446
+ """
447
+
448
+ def __call__(
449
+ self,
450
+ attn: Attention,
451
+ hidden_states,
452
+ encoder_hidden_states=None,
453
+ attention_mask=None,
454
+ temb=None,
455
+ ):
456
+ residual = hidden_states
457
+
458
+ if attn.spatial_norm is not None:
459
+ hidden_states = attn.spatial_norm(hidden_states, temb)
460
+
461
+ input_ndim = hidden_states.ndim
462
+
463
+ if input_ndim == 4:
464
+ batch_size, channel, height, width = hidden_states.shape
465
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
466
+
467
+ batch_size, sequence_length, _ = (
468
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
469
+ )
470
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
471
+
472
+ if attn.group_norm is not None:
473
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
474
+
475
+ query = attn.to_q(hidden_states)
476
+
477
+ if encoder_hidden_states is None:
478
+ encoder_hidden_states = hidden_states
479
+ elif attn.norm_cross:
480
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
481
+
482
+ key = attn.to_k(encoder_hidden_states)
483
+ value = attn.to_v(encoder_hidden_states)
484
+
485
+ query = attn.head_to_batch_dim(query)
486
+ key = attn.head_to_batch_dim(key)
487
+ value = attn.head_to_batch_dim(value)
488
+
489
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
490
+ hidden_states = torch.bmm(attention_probs, value)
491
+ hidden_states = attn.batch_to_head_dim(hidden_states)
492
+
493
+ # linear proj
494
+ hidden_states = attn.to_out[0](hidden_states)
495
+ # dropout
496
+ hidden_states = attn.to_out[1](hidden_states)
497
+
498
+ if input_ndim == 4:
499
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
500
+
501
+ if attn.residual_connection:
502
+ hidden_states = hidden_states + residual
503
+
504
+ hidden_states = hidden_states / attn.rescale_output_factor
505
+
506
+ return hidden_states
507
+
508
+
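# Annotation added in this listing (not part of the upstream diffusers file): a
# shape walk-through of the default AttnProcessor self-attention path above,
# assuming batch_size=2, seq_len=16, heads=8, dim_head=64 (so inner_dim = 512)
# and the default scale_qk=True:
#   to_q / to_k / to_v:      (2, 16, 512)
#   head_to_batch_dim:       (16, 16, 64)   heads folded into the batch axis
#   get_attention_scores:    (16, 16, 16)   softmax(Q @ K^T * dim_head ** -0.5)
#   torch.bmm with value:    (16, 16, 64)
#   batch_to_head_dim:       (2, 16, 512)   heads unfolded and concatenated back
#   to_out[0], to_out[1]:    (2, 16, query_dim) after output projection and dropout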
509
+ class LoRALinearLayer(nn.Module):
510
+ def __init__(self, in_features, out_features, rank=4, network_alpha=None):
511
+ super().__init__()
512
+
513
+ if rank > min(in_features, out_features):
514
+ raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
515
+
516
+ self.down = nn.Linear(in_features, rank, bias=False)
517
+ self.up = nn.Linear(rank, out_features, bias=False)
518
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
519
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
520
+ self.network_alpha = network_alpha
521
+ self.rank = rank
522
+
523
+ nn.init.normal_(self.down.weight, std=1 / rank)
524
+ nn.init.zeros_(self.up.weight)
525
+
526
+ def forward(self, hidden_states):
527
+ orig_dtype = hidden_states.dtype
528
+ dtype = self.down.weight.dtype
529
+
530
+ down_hidden_states = self.down(hidden_states.to(dtype))
531
+ up_hidden_states = self.up(down_hidden_states)
532
+
533
+ if self.network_alpha is not None:
534
+ up_hidden_states *= self.network_alpha / self.rank
535
+
536
+ return up_hidden_states.to(orig_dtype)
537
+
538
+
539
+ class LoRAAttnProcessor(nn.Module):
540
+ r"""
541
+ Processor for implementing the LoRA attention mechanism.
542
+
543
+ Args:
544
+ hidden_size (`int`, *optional*):
545
+ The hidden size of the attention layer.
546
+ cross_attention_dim (`int`, *optional*):
547
+ The number of channels in the `encoder_hidden_states`.
548
+ rank (`int`, defaults to 4):
549
+ The dimension of the LoRA update matrices.
550
+ network_alpha (`int`, *optional*):
551
+ Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
552
+ """
553
+
554
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
555
+ super().__init__()
556
+
557
+ self.hidden_size = hidden_size
558
+ self.cross_attention_dim = cross_attention_dim
559
+ self.rank = rank
560
+
561
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
562
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
563
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
564
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
565
+
566
+ def __call__(
567
+ self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
568
+ ):
569
+ residual = hidden_states
570
+
571
+ if attn.spatial_norm is not None:
572
+ hidden_states = attn.spatial_norm(hidden_states, temb)
573
+
574
+ input_ndim = hidden_states.ndim
575
+
576
+ if input_ndim == 4:
577
+ batch_size, channel, height, width = hidden_states.shape
578
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
579
+
580
+ batch_size, sequence_length, _ = (
581
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
582
+ )
583
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
584
+
585
+ if attn.group_norm is not None:
586
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
587
+
588
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
589
+ query = attn.head_to_batch_dim(query)
590
+
591
+ if encoder_hidden_states is None:
592
+ encoder_hidden_states = hidden_states
593
+ elif attn.norm_cross:
594
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
595
+
596
+ key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
597
+ value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
598
+
599
+ key = attn.head_to_batch_dim(key)
600
+ value = attn.head_to_batch_dim(value)
601
+
602
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
603
+ hidden_states = torch.bmm(attention_probs, value)
604
+ hidden_states = attn.batch_to_head_dim(hidden_states)
605
+
606
+ # linear proj
607
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
608
+ # dropout
609
+ hidden_states = attn.to_out[1](hidden_states)
610
+
611
+ if input_ndim == 4:
612
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
613
+
614
+ if attn.residual_connection:
615
+ hidden_states = hidden_states + residual
616
+
617
+ hidden_states = hidden_states / attn.rescale_output_factor
618
+
619
+ return hidden_states
620
+
621
+
622
+ class CustomDiffusionAttnProcessor(nn.Module):
623
+ r"""
624
+ Processor for implementing attention for the Custom Diffusion method.
625
+
626
+ Args:
627
+ train_kv (`bool`, defaults to `True`):
628
+ Whether to newly train the key and value matrices corresponding to the text features.
629
+ train_q_out (`bool`, defaults to `True`):
630
+ Whether to newly train query matrices corresponding to the latent image features.
631
+ hidden_size (`int`, *optional*, defaults to `None`):
632
+ The hidden size of the attention layer.
633
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
634
+ The number of channels in the `encoder_hidden_states`.
635
+ out_bias (`bool`, defaults to `True`):
636
+ Whether to include the bias parameter in `train_q_out`.
637
+ dropout (`float`, *optional*, defaults to 0.0):
638
+ The dropout probability to use.
639
+ """
640
+
641
+ def __init__(
642
+ self,
643
+ train_kv=True,
644
+ train_q_out=True,
645
+ hidden_size=None,
646
+ cross_attention_dim=None,
647
+ out_bias=True,
648
+ dropout=0.0,
649
+ ):
650
+ super().__init__()
651
+ self.train_kv = train_kv
652
+ self.train_q_out = train_q_out
653
+
654
+ self.hidden_size = hidden_size
655
+ self.cross_attention_dim = cross_attention_dim
656
+
657
+ # `_custom_diffusion` id for easy serialization and loading.
658
+ if self.train_kv:
659
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
660
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
661
+ if self.train_q_out:
662
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
663
+ self.to_out_custom_diffusion = nn.ModuleList([])
664
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
665
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
666
+
667
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
668
+ batch_size, sequence_length, _ = hidden_states.shape
669
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
670
+ if self.train_q_out:
671
+ query = self.to_q_custom_diffusion(hidden_states)
672
+ else:
673
+ query = attn.to_q(hidden_states)
674
+
675
+ if encoder_hidden_states is None:
676
+ crossattn = False
677
+ encoder_hidden_states = hidden_states
678
+ else:
679
+ crossattn = True
680
+ if attn.norm_cross:
681
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
682
+
683
+ if self.train_kv:
684
+ key = self.to_k_custom_diffusion(encoder_hidden_states)
685
+ value = self.to_v_custom_diffusion(encoder_hidden_states)
686
+ else:
687
+ key = attn.to_k(encoder_hidden_states)
688
+ value = attn.to_v(encoder_hidden_states)
689
+
690
+ if crossattn:
691
+ detach = torch.ones_like(key)
692
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
693
+ key = detach * key + (1 - detach) * key.detach()
694
+ value = detach * value + (1 - detach) * value.detach()
695
+
696
+ query = attn.head_to_batch_dim(query)
697
+ key = attn.head_to_batch_dim(key)
698
+ value = attn.head_to_batch_dim(value)
699
+
700
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
701
+ hidden_states = torch.bmm(attention_probs, value)
702
+ hidden_states = attn.batch_to_head_dim(hidden_states)
703
+
704
+ if self.train_q_out:
705
+ # linear proj
706
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
707
+ # dropout
708
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
709
+ else:
710
+ # linear proj
711
+ hidden_states = attn.to_out[0](hidden_states)
712
+ # dropout
713
+ hidden_states = attn.to_out[1](hidden_states)
714
+
715
+ return hidden_states
716
+
717
+
718
+ class AttnAddedKVProcessor:
719
+ r"""
720
+ Processor for performing attention-related computations with extra learnable key and value matrices for the text
721
+ encoder.
722
+ """
723
+
724
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
725
+ residual = hidden_states
726
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
727
+ batch_size, sequence_length, _ = hidden_states.shape
728
+
729
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
730
+
731
+ if encoder_hidden_states is None:
732
+ encoder_hidden_states = hidden_states
733
+ elif attn.norm_cross:
734
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
735
+
736
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
737
+
738
+ query = attn.to_q(hidden_states)
739
+ query = attn.head_to_batch_dim(query)
740
+
741
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
742
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
743
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
744
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
745
+
746
+ if not attn.only_cross_attention:
747
+ key = attn.to_k(hidden_states)
748
+ value = attn.to_v(hidden_states)
749
+ key = attn.head_to_batch_dim(key)
750
+ value = attn.head_to_batch_dim(value)
751
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
752
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
753
+ else:
754
+ key = encoder_hidden_states_key_proj
755
+ value = encoder_hidden_states_value_proj
756
+
757
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
758
+ hidden_states = torch.bmm(attention_probs, value)
759
+ hidden_states = attn.batch_to_head_dim(hidden_states)
760
+
761
+ # linear proj
762
+ hidden_states = attn.to_out[0](hidden_states)
763
+ # dropout
764
+ hidden_states = attn.to_out[1](hidden_states)
765
+
766
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
767
+ hidden_states = hidden_states + residual
768
+
769
+ return hidden_states
770
+
771
+
772
+ class AttnAddedKVProcessor2_0:
773
+ r"""
774
+ Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
775
+ learnable key and value matrices for the text encoder.
776
+ """
777
+
778
+ def __init__(self):
779
+ if not hasattr(F, "scaled_dot_product_attention"):
780
+ raise ImportError(
781
+ "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
782
+ )
783
+
784
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
785
+ residual = hidden_states
786
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
787
+ batch_size, sequence_length, _ = hidden_states.shape
788
+
789
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
790
+
791
+ if encoder_hidden_states is None:
792
+ encoder_hidden_states = hidden_states
793
+ elif attn.norm_cross:
794
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
795
+
796
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
797
+
798
+ query = attn.to_q(hidden_states)
799
+ query = attn.head_to_batch_dim(query, out_dim=4)
800
+
801
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
802
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
803
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
804
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
805
+
806
+ if not attn.only_cross_attention:
807
+ key = attn.to_k(hidden_states)
808
+ value = attn.to_v(hidden_states)
809
+ key = attn.head_to_batch_dim(key, out_dim=4)
810
+ value = attn.head_to_batch_dim(value, out_dim=4)
811
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
812
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
813
+ else:
814
+ key = encoder_hidden_states_key_proj
815
+ value = encoder_hidden_states_value_proj
816
+
817
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
818
+ # TODO: add support for attn.scale when we move to Torch 2.1
819
+ hidden_states = F.scaled_dot_product_attention(
820
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
821
+ )
822
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
823
+
824
+ # linear proj
825
+ hidden_states = attn.to_out[0](hidden_states)
826
+ # dropout
827
+ hidden_states = attn.to_out[1](hidden_states)
828
+
829
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
830
+ hidden_states = hidden_states + residual
831
+
832
+ return hidden_states
833
+
834
+
835
+ class LoRAAttnAddedKVProcessor(nn.Module):
836
+ r"""
837
+ Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
838
+ encoder.
839
+
840
+ Args:
841
+ hidden_size (`int`, *optional*):
842
+ The hidden size of the attention layer.
843
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
844
+ The number of channels in the `encoder_hidden_states`.
845
+ rank (`int`, defaults to 4):
846
+ The dimension of the LoRA update matrices.
847
+
848
+ """
849
+
850
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
851
+ super().__init__()
852
+
853
+ self.hidden_size = hidden_size
854
+ self.cross_attention_dim = cross_attention_dim
855
+ self.rank = rank
856
+
857
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
858
+ self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
859
+ self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
860
+ self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
861
+ self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
862
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
863
+
864
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
865
+ residual = hidden_states
866
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
867
+ batch_size, sequence_length, _ = hidden_states.shape
868
+
869
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
870
+
871
+ if encoder_hidden_states is None:
872
+ encoder_hidden_states = hidden_states
873
+ elif attn.norm_cross:
874
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
875
+
876
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
877
+
878
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
879
+ query = attn.head_to_batch_dim(query)
880
+
881
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + scale * self.add_k_proj_lora(
882
+ encoder_hidden_states
883
+ )
884
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + scale * self.add_v_proj_lora(
885
+ encoder_hidden_states
886
+ )
887
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
888
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
889
+
890
+ if not attn.only_cross_attention:
891
+ key = attn.to_k(hidden_states) + scale * self.to_k_lora(hidden_states)
892
+ value = attn.to_v(hidden_states) + scale * self.to_v_lora(hidden_states)
893
+ key = attn.head_to_batch_dim(key)
894
+ value = attn.head_to_batch_dim(value)
895
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
896
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
897
+ else:
898
+ key = encoder_hidden_states_key_proj
899
+ value = encoder_hidden_states_value_proj
900
+
901
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
902
+ hidden_states = torch.bmm(attention_probs, value)
903
+ hidden_states = attn.batch_to_head_dim(hidden_states)
904
+
905
+ # linear proj
906
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
907
+ # dropout
908
+ hidden_states = attn.to_out[1](hidden_states)
909
+
910
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
911
+ hidden_states = hidden_states + residual
912
+
913
+ return hidden_states
914
+
915
+
916
+ class XFormersAttnAddedKVProcessor:
917
+ r"""
918
+ Processor for implementing memory efficient attention using xFormers.
919
+
920
+ Args:
921
+ attention_op (`Callable`, *optional*, defaults to `None`):
922
+ The base
923
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
924
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
925
+ operator.
926
+ """
927
+
928
+ def __init__(self, attention_op: Optional[Callable] = None):
929
+ self.attention_op = attention_op
930
+
931
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
932
+ residual = hidden_states
933
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
934
+ batch_size, sequence_length, _ = hidden_states.shape
935
+
936
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
937
+
938
+ if encoder_hidden_states is None:
939
+ encoder_hidden_states = hidden_states
940
+ elif attn.norm_cross:
941
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
942
+
943
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
944
+
945
+ query = attn.to_q(hidden_states)
946
+ query = attn.head_to_batch_dim(query)
947
+
948
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
949
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
950
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
951
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
952
+
953
+ if not attn.only_cross_attention:
954
+ key = attn.to_k(hidden_states)
955
+ value = attn.to_v(hidden_states)
956
+ key = attn.head_to_batch_dim(key)
957
+ value = attn.head_to_batch_dim(value)
958
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
959
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
960
+ else:
961
+ key = encoder_hidden_states_key_proj
962
+ value = encoder_hidden_states_value_proj
963
+
964
+ hidden_states = xformers.ops.memory_efficient_attention(
965
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
966
+ )
967
+ hidden_states = hidden_states.to(query.dtype)
968
+ hidden_states = attn.batch_to_head_dim(hidden_states)
969
+
970
+ # linear proj
971
+ hidden_states = attn.to_out[0](hidden_states)
972
+ # dropout
973
+ hidden_states = attn.to_out[1](hidden_states)
974
+
975
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
976
+ hidden_states = hidden_states + residual
977
+
978
+ return hidden_states
979
+
980
+
981
+ class XFormersAttnProcessor:
982
+ r"""
983
+ Processor for implementing memory efficient attention using xFormers.
984
+
985
+ Args:
986
+ attention_op (`Callable`, *optional*, defaults to `None`):
987
+ The base
988
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
989
+ use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
990
+ operator.
991
+ """
992
+
993
+ def __init__(self, attention_op: Optional[Callable] = None):
994
+ self.attention_op = attention_op
995
+
996
+ def __call__(
997
+ self,
998
+ attn: Attention,
999
+ hidden_states: torch.FloatTensor,
1000
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1001
+ attention_mask: Optional[torch.FloatTensor] = None,
1002
+ temb: Optional[torch.FloatTensor] = None,
1003
+ ):
1004
+ residual = hidden_states
1005
+
1006
+ if attn.spatial_norm is not None:
1007
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1008
+
1009
+ input_ndim = hidden_states.ndim
1010
+
1011
+ if input_ndim == 4:
1012
+ batch_size, channel, height, width = hidden_states.shape
1013
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1014
+
1015
+ batch_size, key_tokens, _ = (
1016
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1017
+ )
1018
+
1019
+ attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
1020
+ if attention_mask is not None:
1021
+ # expand our mask's singleton query_tokens dimension:
1022
+ # [batch*heads, 1, key_tokens] ->
1023
+ # [batch*heads, query_tokens, key_tokens]
1024
+ # so that it can be added as a bias onto the attention scores that xformers computes:
1025
+ # [batch*heads, query_tokens, key_tokens]
1026
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
1027
+ _, query_tokens, _ = hidden_states.shape
1028
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
1029
+
1030
+ if attn.group_norm is not None:
1031
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1032
+
1033
+ query = attn.to_q(hidden_states)
1034
+
1035
+ if encoder_hidden_states is None:
1036
+ encoder_hidden_states = hidden_states
1037
+ elif attn.norm_cross:
1038
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1039
+
1040
+ key = attn.to_k(encoder_hidden_states)
1041
+ value = attn.to_v(encoder_hidden_states)
1042
+
1043
+ query = attn.head_to_batch_dim(query).contiguous()
1044
+ key = attn.head_to_batch_dim(key).contiguous()
1045
+ value = attn.head_to_batch_dim(value).contiguous()
1046
+
1047
+ hidden_states = xformers.ops.memory_efficient_attention(
1048
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1049
+ )
1050
+ hidden_states = hidden_states.to(query.dtype)
1051
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1052
+
1053
+ # linear proj
1054
+ hidden_states = attn.to_out[0](hidden_states)
1055
+ # dropout
1056
+ hidden_states = attn.to_out[1](hidden_states)
1057
+
1058
+ if input_ndim == 4:
1059
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1060
+
1061
+ if attn.residual_connection:
1062
+ hidden_states = hidden_states + residual
1063
+
1064
+ hidden_states = hidden_states / attn.rescale_output_factor
1065
+
1066
+ return hidden_states
1067
+
1068
+
1069
+ class AttnProcessor2_0:
1070
+ r"""
1071
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
1072
+ """
1073
+
1074
+ def __init__(self):
1075
+ if not hasattr(F, "scaled_dot_product_attention"):
1076
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
1077
+
1078
+ def __call__(
1079
+ self,
1080
+ attn: Attention,
1081
+ hidden_states,
1082
+ encoder_hidden_states=None,
1083
+ attention_mask=None,
1084
+ temb=None,
1085
+ ):
1086
+ residual = hidden_states
1087
+
1088
+ if attn.spatial_norm is not None:
1089
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1090
+
1091
+ input_ndim = hidden_states.ndim
1092
+
1093
+ if input_ndim == 4:
1094
+ batch_size, channel, height, width = hidden_states.shape
1095
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1096
+
1097
+ batch_size, sequence_length, _ = (
1098
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1099
+ )
1100
+ inner_dim = hidden_states.shape[-1]
1101
+
1102
+ if attention_mask is not None:
1103
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1104
+ # scaled_dot_product_attention expects attention_mask shape to be
1105
+ # (batch, heads, source_length, target_length)
1106
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
1107
+
1108
+ if attn.group_norm is not None:
1109
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1110
+
1111
+ query = attn.to_q(hidden_states)
1112
+
1113
+ if encoder_hidden_states is None:
1114
+ encoder_hidden_states = hidden_states
1115
+ elif attn.norm_cross:
1116
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1117
+
1118
+ key = attn.to_k(encoder_hidden_states)
1119
+ value = attn.to_v(encoder_hidden_states)
1120
+
1121
+ head_dim = inner_dim // attn.heads
1122
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1123
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1124
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1125
+
1126
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
1127
+ # TODO: add support for attn.scale when we move to Torch 2.1
1128
+ hidden_states = F.scaled_dot_product_attention(
1129
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1130
+ )
1131
+
1132
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1133
+ hidden_states = hidden_states.to(query.dtype)
1134
+
1135
+ # linear proj
1136
+ hidden_states = attn.to_out[0](hidden_states)
1137
+ # dropout
1138
+ hidden_states = attn.to_out[1](hidden_states)
1139
+
1140
+ if input_ndim == 4:
1141
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1142
+
1143
+ if attn.residual_connection:
1144
+ hidden_states = hidden_states + residual
1145
+
1146
+ hidden_states = hidden_states / attn.rescale_output_factor
1147
+
1148
+ return hidden_states
1149
+
1150
+
1151
+ class LoRAXFormersAttnProcessor(nn.Module):
1152
+ r"""
1153
+ Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.
1154
+
1155
+ Args:
1156
+ hidden_size (`int`, *optional*):
1157
+ The hidden size of the attention layer.
1158
+ cross_attention_dim (`int`, *optional*):
1159
+ The number of channels in the `encoder_hidden_states`.
1160
+ rank (`int`, defaults to 4):
1161
+ The dimension of the LoRA update matrices.
1162
+ attention_op (`Callable`, *optional*, defaults to `None`):
1163
+ The base
1164
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
1165
+ use as the attention operator. It is recommended to set this to `None` and allow xFormers to choose the best
1166
+ operator.
1167
+ network_alpha (`int`, *optional*):
1168
+ Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
1169
+
1170
+ """
1171
+
1172
+ def __init__(
1173
+ self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None
1174
+ ):
1175
+ super().__init__()
1176
+
1177
+ self.hidden_size = hidden_size
1178
+ self.cross_attention_dim = cross_attention_dim
1179
+ self.rank = rank
1180
+ self.attention_op = attention_op
1181
+
1182
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1183
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1184
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1185
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1186
+
1187
+ def __call__(
1188
+ self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
1189
+ ):
1190
+ residual = hidden_states
1191
+
1192
+ if attn.spatial_norm is not None:
1193
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1194
+
1195
+ input_ndim = hidden_states.ndim
1196
+
1197
+ if input_ndim == 4:
1198
+ batch_size, channel, height, width = hidden_states.shape
1199
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1200
+
1201
+ batch_size, sequence_length, _ = (
1202
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1203
+ )
1204
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1205
+
1206
+ if attn.group_norm is not None:
1207
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1208
+
1209
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
1210
+ query = attn.head_to_batch_dim(query).contiguous()
1211
+
1212
+ if encoder_hidden_states is None:
1213
+ encoder_hidden_states = hidden_states
1214
+ elif attn.norm_cross:
1215
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1216
+
1217
+ key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
1218
+ value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
1219
+
1220
+ key = attn.head_to_batch_dim(key).contiguous()
1221
+ value = attn.head_to_batch_dim(value).contiguous()
1222
+
1223
+ hidden_states = xformers.ops.memory_efficient_attention(
1224
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1225
+ )
1226
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1227
+
1228
+ # linear proj
1229
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
1230
+ # dropout
1231
+ hidden_states = attn.to_out[1](hidden_states)
1232
+
1233
+ if input_ndim == 4:
1234
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1235
+
1236
+ if attn.residual_connection:
1237
+ hidden_states = hidden_states + residual
1238
+
1239
+ hidden_states = hidden_states / attn.rescale_output_factor
1240
+
1241
+ return hidden_states
1242
+
1243
+
1244
+ class LoRAAttnProcessor2_0(nn.Module):
1245
+ r"""
1246
+ Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product
1247
+ attention.
1248
+
1249
+ Args:
1250
+ hidden_size (`int`):
1251
+ The hidden size of the attention layer.
1252
+ cross_attention_dim (`int`, *optional*):
1253
+ The number of channels in the `encoder_hidden_states`.
1254
+ rank (`int`, defaults to 4):
1255
+ The dimension of the LoRA update matrices.
1256
+ network_alpha (`int`, *optional*):
1257
+ Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
1258
+ """
1259
+
1260
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
1261
+ super().__init__()
1262
+ if not hasattr(F, "scaled_dot_product_attention"):
1263
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
1264
+
1265
+ self.hidden_size = hidden_size
1266
+ self.cross_attention_dim = cross_attention_dim
1267
+ self.rank = rank
1268
+
1269
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1270
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1271
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
1272
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
1273
+
1274
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
1275
+ residual = hidden_states
1276
+
1277
+ input_ndim = hidden_states.ndim
1278
+
1279
+ if input_ndim == 4:
1280
+ batch_size, channel, height, width = hidden_states.shape
1281
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1282
+
1283
+ batch_size, sequence_length, _ = (
1284
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1285
+ )
1286
+ inner_dim = hidden_states.shape[-1]
1287
+
1288
+ if attention_mask is not None:
1289
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1290
+ # scaled_dot_product_attention expects attention_mask shape to be
1291
+ # (batch, heads, source_length, target_length)
1292
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
1293
+
1294
+ if attn.group_norm is not None:
1295
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1296
+
1297
+ query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
1298
+
1299
+ if encoder_hidden_states is None:
1300
+ encoder_hidden_states = hidden_states
1301
+ elif attn.norm_cross:
1302
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1303
+
1304
+ key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
1305
+ value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
1306
+
1307
+ head_dim = inner_dim // attn.heads
1308
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1309
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1310
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1311
+
1312
+ # TODO: add support for attn.scale when we move to Torch 2.1
1313
+ hidden_states = F.scaled_dot_product_attention(
1314
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1315
+ )
1316
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1317
+ hidden_states = hidden_states.to(query.dtype)
1318
+
1319
+ # linear proj
1320
+ hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
1321
+ # dropout
1322
+ hidden_states = attn.to_out[1](hidden_states)
1323
+
1324
+ if input_ndim == 4:
1325
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1326
+
1327
+ if attn.residual_connection:
1328
+ hidden_states = hidden_states + residual
1329
+
1330
+ hidden_states = hidden_states / attn.rescale_output_factor
1331
+
1332
+ return hidden_states
1333
+
1334
+
1335
+ class CustomDiffusionXFormersAttnProcessor(nn.Module):
1336
+ r"""
1337
+ Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
1338
+
1339
+ Args:
1340
+ train_kv (`bool`, defaults to `True`):
1341
+ Whether to newly train the key and value matrices corresponding to the text features.
1342
+ train_q_out (`bool`, defaults to `True`):
1343
+ Whether to newly train query matrices corresponding to the latent image features.
1344
+ hidden_size (`int`, *optional*, defaults to `None`):
1345
+ The hidden size of the attention layer.
1346
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
1347
+ The number of channels in the `encoder_hidden_states`.
1348
+ out_bias (`bool`, defaults to `True`):
1349
+ Whether to include the bias parameter in `train_q_out`.
1350
+ dropout (`float`, *optional*, defaults to 0.0):
1351
+ The dropout probability to use.
1352
+ attention_op (`Callable`, *optional*, defaults to `None`):
1353
+ The base
1354
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
1355
+ as the attention operator. It is recommended to set this to `None` and allow xFormers to choose the best operator.
1356
+ """
1357
+
1358
+ def __init__(
1359
+ self,
1360
+ train_kv=True,
1361
+ train_q_out=False,
1362
+ hidden_size=None,
1363
+ cross_attention_dim=None,
1364
+ out_bias=True,
1365
+ dropout=0.0,
1366
+ attention_op: Optional[Callable] = None,
1367
+ ):
1368
+ super().__init__()
1369
+ self.train_kv = train_kv
1370
+ self.train_q_out = train_q_out
1371
+
1372
+ self.hidden_size = hidden_size
1373
+ self.cross_attention_dim = cross_attention_dim
1374
+ self.attention_op = attention_op
1375
+
1376
+ # `_custom_diffusion` id for easy serialization and loading.
1377
+ if self.train_kv:
1378
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1379
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1380
+ if self.train_q_out:
1381
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
1382
+ self.to_out_custom_diffusion = nn.ModuleList([])
1383
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
1384
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
1385
+
1386
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
1387
+ batch_size, sequence_length, _ = (
1388
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1389
+ )
1390
+
1391
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1392
+
1393
+ if self.train_q_out:
1394
+ query = self.to_q_custom_diffusion(hidden_states)
1395
+ else:
1396
+ query = attn.to_q(hidden_states)
1397
+
1398
+ if encoder_hidden_states is None:
1399
+ crossattn = False
1400
+ encoder_hidden_states = hidden_states
1401
+ else:
1402
+ crossattn = True
1403
+ if attn.norm_cross:
1404
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1405
+
1406
+ if self.train_kv:
1407
+ key = self.to_k_custom_diffusion(encoder_hidden_states)
1408
+ value = self.to_v_custom_diffusion(encoder_hidden_states)
1409
+ else:
1410
+ key = attn.to_k(encoder_hidden_states)
1411
+ value = attn.to_v(encoder_hidden_states)
1412
+
1413
+ if crossattn:
1414
+ detach = torch.ones_like(key)
1415
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
1416
+ key = detach * key + (1 - detach) * key.detach()
1417
+ value = detach * value + (1 - detach) * value.detach()
1418
+
1419
+ query = attn.head_to_batch_dim(query).contiguous()
1420
+ key = attn.head_to_batch_dim(key).contiguous()
1421
+ value = attn.head_to_batch_dim(value).contiguous()
1422
+
1423
+ hidden_states = xformers.ops.memory_efficient_attention(
1424
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
1425
+ )
1426
+ hidden_states = hidden_states.to(query.dtype)
1427
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1428
+
1429
+ if self.train_q_out:
1430
+ # linear proj
1431
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
1432
+ # dropout
1433
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
1434
+ else:
1435
+ # linear proj
1436
+ hidden_states = attn.to_out[0](hidden_states)
1437
+ # dropout
1438
+ hidden_states = attn.to_out[1](hidden_states)
1439
+ return hidden_states
1440
+
1441
+
1442
+ class SlicedAttnProcessor:
1443
+ r"""
1444
+ Processor for implementing sliced attention.
1445
+
1446
+ Args:
1447
+ slice_size (`int`, *optional*):
1448
+ The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
1449
+ `attention_head_dim` must be a multiple of the `slice_size`.
1450
+ """
1451
+
1452
+ def __init__(self, slice_size):
1453
+ self.slice_size = slice_size
1454
+
1455
+ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
1456
+ residual = hidden_states
1457
+
1458
+ input_ndim = hidden_states.ndim
1459
+
1460
+ if input_ndim == 4:
1461
+ batch_size, channel, height, width = hidden_states.shape
1462
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1463
+
1464
+ batch_size, sequence_length, _ = (
1465
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1466
+ )
1467
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1468
+
1469
+ if attn.group_norm is not None:
1470
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1471
+
1472
+ query = attn.to_q(hidden_states)
1473
+ dim = query.shape[-1]
1474
+ query = attn.head_to_batch_dim(query)
1475
+
1476
+ if encoder_hidden_states is None:
1477
+ encoder_hidden_states = hidden_states
1478
+ elif attn.norm_cross:
1479
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1480
+
1481
+ key = attn.to_k(encoder_hidden_states)
1482
+ value = attn.to_v(encoder_hidden_states)
1483
+ key = attn.head_to_batch_dim(key)
1484
+ value = attn.head_to_batch_dim(value)
1485
+
1486
+ batch_size_attention, query_tokens, _ = query.shape
1487
+ hidden_states = torch.zeros(
1488
+ (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
1489
+ )
1490
+
1491
+ for i in range(batch_size_attention // self.slice_size):
1492
+ start_idx = i * self.slice_size
1493
+ end_idx = (i + 1) * self.slice_size
1494
+
1495
+ query_slice = query[start_idx:end_idx]
1496
+ key_slice = key[start_idx:end_idx]
1497
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
1498
+
1499
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
1500
+
1501
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
1502
+
1503
+ hidden_states[start_idx:end_idx] = attn_slice
1504
+
1505
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1506
+
1507
+ # linear proj
1508
+ hidden_states = attn.to_out[0](hidden_states)
1509
+ # dropout
1510
+ hidden_states = attn.to_out[1](hidden_states)
1511
+
1512
+ if input_ndim == 4:
1513
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1514
+
1515
+ if attn.residual_connection:
1516
+ hidden_states = hidden_states + residual
1517
+
1518
+ hidden_states = hidden_states / attn.rescale_output_factor
1519
+
1520
+ return hidden_states
1521
+
1522
+
1523
+ class SlicedAttnAddedKVProcessor:
1524
+ r"""
1525
+ Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.
1526
+
1527
+ Args:
1528
+ slice_size (`int`, *optional*):
1529
+ The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
1530
+ `attention_head_dim` must be a multiple of the `slice_size`.
1531
+ """
1532
+
1533
+ def __init__(self, slice_size):
1534
+ self.slice_size = slice_size
1535
+
1536
+ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
1537
+ residual = hidden_states
1538
+
1539
+ if attn.spatial_norm is not None:
1540
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1541
+
1542
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
1543
+
1544
+ batch_size, sequence_length, _ = hidden_states.shape
1545
+
1546
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1547
+
1548
+ if encoder_hidden_states is None:
1549
+ encoder_hidden_states = hidden_states
1550
+ elif attn.norm_cross:
1551
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1552
+
1553
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1554
+
1555
+ query = attn.to_q(hidden_states)
1556
+ dim = query.shape[-1]
1557
+ query = attn.head_to_batch_dim(query)
1558
+
1559
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
1560
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
1561
+
1562
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
1563
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
1564
+
1565
+ if not attn.only_cross_attention:
1566
+ key = attn.to_k(hidden_states)
1567
+ value = attn.to_v(hidden_states)
1568
+ key = attn.head_to_batch_dim(key)
1569
+ value = attn.head_to_batch_dim(value)
1570
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
1571
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
1572
+ else:
1573
+ key = encoder_hidden_states_key_proj
1574
+ value = encoder_hidden_states_value_proj
1575
+
1576
+ batch_size_attention, query_tokens, _ = query.shape
1577
+ hidden_states = torch.zeros(
1578
+ (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
1579
+ )
1580
+
1581
+ for i in range(batch_size_attention // self.slice_size):
1582
+ start_idx = i * self.slice_size
1583
+ end_idx = (i + 1) * self.slice_size
1584
+
1585
+ query_slice = query[start_idx:end_idx]
1586
+ key_slice = key[start_idx:end_idx]
1587
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
1588
+
1589
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
1590
+
1591
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
1592
+
1593
+ hidden_states[start_idx:end_idx] = attn_slice
1594
+
1595
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1596
+
1597
+ # linear proj
1598
+ hidden_states = attn.to_out[0](hidden_states)
1599
+ # dropout
1600
+ hidden_states = attn.to_out[1](hidden_states)
1601
+
1602
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
1603
+ hidden_states = hidden_states + residual
1604
+
1605
+ return hidden_states
1606
+
1607
+
1608
+ AttentionProcessor = Union[
1609
+ AttnProcessor,
1610
+ AttnProcessor2_0,
1611
+ XFormersAttnProcessor,
1612
+ SlicedAttnProcessor,
1613
+ AttnAddedKVProcessor,
1614
+ SlicedAttnAddedKVProcessor,
1615
+ AttnAddedKVProcessor2_0,
1616
+ XFormersAttnAddedKVProcessor,
1617
+ LoRAAttnProcessor,
1618
+ LoRAXFormersAttnProcessor,
1619
+ LoRAAttnProcessor2_0,
1620
+ LoRAAttnAddedKVProcessor,
1621
+ CustomDiffusionAttnProcessor,
1622
+ CustomDiffusionXFormersAttnProcessor,
1623
+ ]
1624
+
1625
+
1626
+ class SpatialNorm(nn.Module):
1627
+ """
1628
+ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002
1629
+ """
1630
+
1631
+ def __init__(
1632
+ self,
1633
+ f_channels,
1634
+ zq_channels,
1635
+ ):
1636
+ super().__init__()
1637
+ self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
1638
+ self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
1639
+ self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
1640
+
1641
+ def forward(self, f, zq):
1642
+ f_size = f.shape[-2:]
1643
+ zq = F.interpolate(zq, size=f_size, mode="nearest")
1644
+ norm_f = self.norm_layer(f)
1645
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
1646
+ return new_f
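
The processors above mirror the upstream diffusers attention processors and are interchangeable at runtime: one processor is installed per attention module depending on the available backend (xFormers, PyTorch 2.0 SDPA, sliced attention). Below is a minimal sketch, not part of this commit, of how a processor might be selected and installed; it assumes the vendored `diffusers` package is importable from the repo root and that the UNet exposes `set_attn_processor`, as the upstream `UNet2DConditionModel` does.

```python
# Illustrative sketch only; the checkpoint path and the `set_attn_processor`
# call are assumptions, not code from this commit.
import torch.nn.functional as F

from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0


def pick_processor():
    # Prefer PyTorch 2.0's fused scaled-dot-product attention when available,
    # otherwise fall back to the default (matmul + softmax) processor.
    if hasattr(F, "scaled_dot_product_attention"):
        return AttnProcessor2_0()
    return AttnProcessor()


print(type(pick_processor()).__name__)
# unet = UNet2DConditionModel.from_pretrained("path/to/unet")  # hypothetical path
# unet.set_attn_processor(pick_processor())
```
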
diffusers/models/dual_transformer_2d.py ADDED
@@ -0,0 +1,151 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Optional
15
+
16
+ from torch import nn
17
+
18
+ from .transformer_2d import Transformer2DModel, Transformer2DModelOutput
19
+
20
+
21
+ class DualTransformer2DModel(nn.Module):
22
+ """
23
+ Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
24
+
25
+ Parameters:
26
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
27
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
28
+ in_channels (`int`, *optional*):
29
+ Pass if the input is continuous. The number of channels in the input and output.
30
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
31
+ dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
32
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
33
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
34
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
35
+ `ImagePositionalEmbeddings`.
36
+ num_vector_embeds (`int`, *optional*):
37
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
38
+ Includes the class for the masked latent pixel.
39
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
40
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
41
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
42
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
43
+ up to, but not more than, `num_embeds_ada_norm` steps.
44
+ attention_bias (`bool`, *optional*):
45
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
46
+ """
47
+
48
+ def __init__(
49
+ self,
50
+ num_attention_heads: int = 16,
51
+ attention_head_dim: int = 88,
52
+ in_channels: Optional[int] = None,
53
+ num_layers: int = 1,
54
+ dropout: float = 0.0,
55
+ norm_num_groups: int = 32,
56
+ cross_attention_dim: Optional[int] = None,
57
+ attention_bias: bool = False,
58
+ sample_size: Optional[int] = None,
59
+ num_vector_embeds: Optional[int] = None,
60
+ activation_fn: str = "geglu",
61
+ num_embeds_ada_norm: Optional[int] = None,
62
+ ):
63
+ super().__init__()
64
+ self.transformers = nn.ModuleList(
65
+ [
66
+ Transformer2DModel(
67
+ num_attention_heads=num_attention_heads,
68
+ attention_head_dim=attention_head_dim,
69
+ in_channels=in_channels,
70
+ num_layers=num_layers,
71
+ dropout=dropout,
72
+ norm_num_groups=norm_num_groups,
73
+ cross_attention_dim=cross_attention_dim,
74
+ attention_bias=attention_bias,
75
+ sample_size=sample_size,
76
+ num_vector_embeds=num_vector_embeds,
77
+ activation_fn=activation_fn,
78
+ num_embeds_ada_norm=num_embeds_ada_norm,
79
+ )
80
+ for _ in range(2)
81
+ ]
82
+ )
83
+
84
+ # Variables that can be set by a pipeline:
85
+
86
+ # The ratio of transformer1 to transformer2's output states to be combined during inference
87
+ self.mix_ratio = 0.5
88
+
89
+ # The shape of `encoder_hidden_states` is expected to be
90
+ # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
91
+ self.condition_lengths = [77, 257]
92
+
93
+ # Which transformer to use to encode which condition.
94
+ # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
95
+ self.transformer_index_for_condition = [1, 0]
96
+
97
+ def forward(
98
+ self,
99
+ hidden_states,
100
+ encoder_hidden_states,
101
+ timestep=None,
102
+ attention_mask=None,
103
+ cross_attention_kwargs=None,
104
+ return_dict: bool = True,
105
+ ):
106
+ """
107
+ Args:
108
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
109
+ When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
110
+ hidden_states
111
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
112
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
113
+ self-attention.
114
+ timestep ( `torch.long`, *optional*):
115
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
116
+ attention_mask (`torch.FloatTensor`, *optional*):
117
+ Optional attention mask to be applied in Attention
118
+ return_dict (`bool`, *optional*, defaults to `True`):
119
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
120
+
121
+ Returns:
122
+ [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
123
+ [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
124
+ returning a tuple, the first element is the sample tensor.
125
+ """
126
+ input_states = hidden_states
127
+
128
+ encoded_states = []
129
+ tokens_start = 0
130
+ # attention_mask is not used yet
131
+ for i in range(2):
132
+ # for each of the two transformers, pass the corresponding condition tokens
133
+ condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
134
+ transformer_index = self.transformer_index_for_condition[i]
135
+ encoded_state = self.transformers[transformer_index](
136
+ input_states,
137
+ encoder_hidden_states=condition_state,
138
+ timestep=timestep,
139
+ cross_attention_kwargs=cross_attention_kwargs,
140
+ return_dict=False,
141
+ )[0]
142
+ encoded_states.append(encoded_state - input_states)
143
+ tokens_start += self.condition_lengths[i]
144
+
145
+ output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
146
+ output_states = output_states + input_states
147
+
148
+ if not return_dict:
149
+ return (output_states,)
150
+
151
+ return Transformer2DModelOutput(sample=output_states)
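
The forward pass above splits the concatenated conditioning sequence according to `condition_lengths`, routes each chunk through one of the two transformers, keeps only each branch's residual, and blends the residuals with `mix_ratio`. The sketch below reproduces just that arithmetic with stand-in branches so the shapes and the blend are easy to follow; it is illustrative and not taken from this commit.

```python
# Illustrative sketch of the DualTransformer2DModel mixing logic.
# The transformer calls are stubbed out so only the arithmetic is shown.
import torch

batch, dim = 2, 64
condition_lengths = [77, 257]      # e.g. text tokens, image tokens
mix_ratio = 0.5

hidden = torch.randn(batch, 16, dim)                     # input hidden states
cond = torch.randn(batch, sum(condition_lengths), dim)   # concatenated conditions

# Stand-ins for transformers[1](cond_a) and transformers[0](cond_b).
branch = [lambda h, c: h + 0.1 * c.mean(dim=1, keepdim=True)] * 2

start, residuals = 0, []
for i, length in enumerate(condition_lengths):
    cond_i = cond[:, start:start + length]
    out_i = branch[i](hidden, cond_i)
    residuals.append(out_i - hidden)   # the model keeps only the residual
    start += length

blended = residuals[0] * mix_ratio + residuals[1] * (1 - mix_ratio) + hidden
print(blended.shape)  # torch.Size([2, 16, 64])
```
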
diffusers/models/embeddings.py ADDED
@@ -0,0 +1,480 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import Optional
17
+
18
+ import numpy as np
19
+ import torch
20
+ from torch import nn
21
+
22
+ from .activations import get_activation
23
+
24
+
25
+ def get_timestep_embedding(
26
+ timesteps: torch.Tensor,
27
+ embedding_dim: int,
28
+ flip_sin_to_cos: bool = False,
29
+ downscale_freq_shift: float = 1,
30
+ scale: float = 1,
31
+ max_period: int = 10000,
32
+ ):
33
+ """
34
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
35
+
36
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
37
+ These may be fractional.
38
+ :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
39
+ embeddings. :return: an [N x dim] Tensor of positional embeddings.
40
+ """
41
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
42
+
43
+ half_dim = embedding_dim // 2
44
+ exponent = -math.log(max_period) * torch.arange(
45
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
46
+ )
47
+ exponent = exponent / (half_dim - downscale_freq_shift)
48
+
49
+ emb = torch.exp(exponent)
50
+ emb = timesteps[:, None].float() * emb[None, :]
51
+
52
+ # scale embeddings
53
+ emb = scale * emb
54
+
55
+ # concat sine and cosine embeddings
56
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
57
+
58
+ # flip sine and cosine embeddings
59
+ if flip_sin_to_cos:
60
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
61
+
62
+ # zero pad
63
+ if embedding_dim % 2 == 1:
64
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
65
+ return emb
66
+
67
+
68
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
69
+ """
70
+ grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
71
+ [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
72
+ """
73
+ grid_h = np.arange(grid_size, dtype=np.float32)
74
+ grid_w = np.arange(grid_size, dtype=np.float32)
75
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
76
+ grid = np.stack(grid, axis=0)
77
+
78
+ grid = grid.reshape([2, 1, grid_size, grid_size])
79
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
80
+ if cls_token and extra_tokens > 0:
81
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
82
+ return pos_embed
83
+
84
+
85
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
86
+ if embed_dim % 2 != 0:
87
+ raise ValueError("embed_dim must be divisible by 2")
88
+
89
+ # use half of dimensions to encode grid_h
90
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
91
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
92
+
93
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
94
+ return emb
95
+
96
+
97
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
98
+ """
99
+ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
100
+ """
101
+ if embed_dim % 2 != 0:
102
+ raise ValueError("embed_dim must be divisible by 2")
103
+
104
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
105
+ omega /= embed_dim / 2.0
106
+ omega = 1.0 / 10000**omega # (D/2,)
107
+
108
+ pos = pos.reshape(-1) # (M,)
109
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
110
+
111
+ emb_sin = np.sin(out) # (M, D/2)
112
+ emb_cos = np.cos(out) # (M, D/2)
113
+
114
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
115
+ return emb
116
+
117
+
118
+ class PatchEmbed(nn.Module):
119
+ """2D Image to Patch Embedding"""
120
+
121
+ def __init__(
122
+ self,
123
+ height=224,
124
+ width=224,
125
+ patch_size=16,
126
+ in_channels=3,
127
+ embed_dim=768,
128
+ layer_norm=False,
129
+ flatten=True,
130
+ bias=True,
131
+ ):
132
+ super().__init__()
133
+
134
+ num_patches = (height // patch_size) * (width // patch_size)
135
+ self.flatten = flatten
136
+ self.layer_norm = layer_norm
137
+
138
+ self.proj = nn.Conv2d(
139
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
140
+ )
141
+ if layer_norm:
142
+ self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
143
+ else:
144
+ self.norm = None
145
+
146
+ pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
147
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
148
+
149
+ def forward(self, latent):
150
+ latent = self.proj(latent)
151
+ if self.flatten:
152
+ latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
153
+ if self.layer_norm:
154
+ latent = self.norm(latent)
155
+ return latent + self.pos_embed
156
+
157
+
158
+ class TimestepEmbedding(nn.Module):
159
+ def __init__(
160
+ self,
161
+ in_channels: int,
162
+ time_embed_dim: int,
163
+ act_fn: str = "silu",
164
+ out_dim: int = None,
165
+ post_act_fn: Optional[str] = None,
166
+ cond_proj_dim=None,
167
+ ):
168
+ super().__init__()
169
+
170
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim)
171
+
172
+ if cond_proj_dim is not None:
173
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
174
+ else:
175
+ self.cond_proj = None
176
+
177
+ self.act = get_activation(act_fn)
178
+
179
+ if out_dim is not None:
180
+ time_embed_dim_out = out_dim
181
+ else:
182
+ time_embed_dim_out = time_embed_dim
183
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
184
+
185
+ if post_act_fn is None:
186
+ self.post_act = None
187
+ else:
188
+ self.post_act = get_activation(post_act_fn)
189
+
190
+ def forward(self, sample, condition=None):
191
+ if condition is not None:
192
+ sample = sample + self.cond_proj(condition)
193
+ sample = self.linear_1(sample)
194
+
195
+ if self.act is not None:
196
+ sample = self.act(sample)
197
+
198
+ sample = self.linear_2(sample)
199
+
200
+ if self.post_act is not None:
201
+ sample = self.post_act(sample)
202
+ return sample
203
+
204
+
205
+ class Timesteps(nn.Module):
206
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
207
+ super().__init__()
208
+ self.num_channels = num_channels
209
+ self.flip_sin_to_cos = flip_sin_to_cos
210
+ self.downscale_freq_shift = downscale_freq_shift
211
+
212
+ def forward(self, timesteps):
213
+ t_emb = get_timestep_embedding(
214
+ timesteps,
215
+ self.num_channels,
216
+ flip_sin_to_cos=self.flip_sin_to_cos,
217
+ downscale_freq_shift=self.downscale_freq_shift,
218
+ )
219
+ return t_emb
220
+
221
+
222
+ class GaussianFourierProjection(nn.Module):
223
+ """Gaussian Fourier embeddings for noise levels."""
224
+
225
+ def __init__(
226
+ self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
227
+ ):
228
+ super().__init__()
229
+ self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
230
+ self.log = log
231
+ self.flip_sin_to_cos = flip_sin_to_cos
232
+
233
+ if set_W_to_weight:
234
+ # to delete later
235
+ self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
236
+
237
+ self.weight = self.W
238
+
239
+ def forward(self, x):
240
+ if self.log:
241
+ x = torch.log(x)
242
+
243
+ x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
244
+
245
+ if self.flip_sin_to_cos:
246
+ out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
247
+ else:
248
+ out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
249
+ return out
250
+
251
+
252
+ class ImagePositionalEmbeddings(nn.Module):
253
+ """
254
+ Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
255
+ height and width of the latent space.
256
+
257
+ For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
258
+
259
+ For VQ-diffusion:
260
+
261
+ Output vector embeddings are used as input for the transformer.
262
+
263
+ Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
264
+
265
+ Args:
266
+ num_embed (`int`):
267
+ Number of embeddings for the latent pixels embeddings.
268
+ height (`int`):
269
+ Height of the latent image i.e. the number of height embeddings.
270
+ width (`int`):
271
+ Width of the latent image i.e. the number of width embeddings.
272
+ embed_dim (`int`):
273
+ Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
274
+ """
275
+
276
+ def __init__(
277
+ self,
278
+ num_embed: int,
279
+ height: int,
280
+ width: int,
281
+ embed_dim: int,
282
+ ):
283
+ super().__init__()
284
+
285
+ self.height = height
286
+ self.width = width
287
+ self.num_embed = num_embed
288
+ self.embed_dim = embed_dim
289
+
290
+ self.emb = nn.Embedding(self.num_embed, embed_dim)
291
+ self.height_emb = nn.Embedding(self.height, embed_dim)
292
+ self.width_emb = nn.Embedding(self.width, embed_dim)
293
+
294
+ def forward(self, index):
295
+ emb = self.emb(index)
296
+
297
+ height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
298
+
299
+ # 1 x H x D -> 1 x H x 1 x D
300
+ height_emb = height_emb.unsqueeze(2)
301
+
302
+ width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
303
+
304
+ # 1 x W x D -> 1 x 1 x W x D
305
+ width_emb = width_emb.unsqueeze(1)
306
+
307
+ pos_emb = height_emb + width_emb
308
+
309
+ # 1 x H x W x D -> 1 x L x D
310
+ pos_emb = pos_emb.view(1, self.height * self.width, -1)
311
+
312
+ emb = emb + pos_emb[:, : emb.shape[1], :]
313
+
314
+ return emb
315
+
316
+
317
+ class LabelEmbedding(nn.Module):
318
+ """
319
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
320
+
321
+ Args:
322
+ num_classes (`int`): The number of classes.
323
+ hidden_size (`int`): The size of the vector embeddings.
324
+ dropout_prob (`float`): The probability of dropping a label.
325
+ """
326
+
327
+ def __init__(self, num_classes, hidden_size, dropout_prob):
328
+ super().__init__()
329
+ use_cfg_embedding = dropout_prob > 0
330
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
331
+ self.num_classes = num_classes
332
+ self.dropout_prob = dropout_prob
333
+
334
+ def token_drop(self, labels, force_drop_ids=None):
335
+ """
336
+ Drops labels to enable classifier-free guidance.
337
+ """
338
+ if force_drop_ids is None:
339
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
340
+ else:
341
+ drop_ids = torch.tensor(force_drop_ids == 1)
342
+ labels = torch.where(drop_ids, self.num_classes, labels)
343
+ return labels
344
+
345
+ def forward(self, labels: torch.LongTensor, force_drop_ids=None):
346
+ use_dropout = self.dropout_prob > 0
347
+ if (self.training and use_dropout) or (force_drop_ids is not None):
348
+ labels = self.token_drop(labels, force_drop_ids)
349
+ embeddings = self.embedding_table(labels)
350
+ return embeddings
351
+
352
+
353
+ class TextImageProjection(nn.Module):
354
+ def __init__(
355
+ self,
356
+ text_embed_dim: int = 1024,
357
+ image_embed_dim: int = 768,
358
+ cross_attention_dim: int = 768,
359
+ num_image_text_embeds: int = 10,
360
+ ):
361
+ super().__init__()
362
+
363
+ self.num_image_text_embeds = num_image_text_embeds
364
+ self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
365
+ self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)
366
+
367
+ def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
368
+ batch_size = text_embeds.shape[0]
369
+
370
+ # image
371
+ image_text_embeds = self.image_embeds(image_embeds)
372
+ image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
373
+
374
+ # text
375
+ text_embeds = self.text_proj(text_embeds)
376
+
377
+ return torch.cat([image_text_embeds, text_embeds], dim=1)
378
+
379
+
380
+ class CombinedTimestepLabelEmbeddings(nn.Module):
381
+ def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
382
+ super().__init__()
383
+
384
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
385
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
386
+ self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
387
+
388
+ def forward(self, timestep, class_labels, hidden_dtype=None):
389
+ timesteps_proj = self.time_proj(timestep)
390
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
391
+
392
+ class_labels = self.class_embedder(class_labels) # (N, D)
393
+
394
+ conditioning = timesteps_emb + class_labels # (N, D)
395
+
396
+ return conditioning
397
+
398
+
399
+ class TextTimeEmbedding(nn.Module):
400
+ def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
401
+ super().__init__()
402
+ self.norm1 = nn.LayerNorm(encoder_dim)
403
+ self.pool = AttentionPooling(num_heads, encoder_dim)
404
+ self.proj = nn.Linear(encoder_dim, time_embed_dim)
405
+ self.norm2 = nn.LayerNorm(time_embed_dim)
406
+
407
+ def forward(self, hidden_states):
408
+ hidden_states = self.norm1(hidden_states)
409
+ hidden_states = self.pool(hidden_states)
410
+ hidden_states = self.proj(hidden_states)
411
+ hidden_states = self.norm2(hidden_states)
412
+ return hidden_states
413
+
414
+
415
+ class TextImageTimeEmbedding(nn.Module):
416
+ def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
417
+ super().__init__()
418
+ self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
419
+ self.text_norm = nn.LayerNorm(time_embed_dim)
420
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
421
+
422
+ def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
423
+ # text
424
+ time_text_embeds = self.text_proj(text_embeds)
425
+ time_text_embeds = self.text_norm(time_text_embeds)
426
+
427
+ # image
428
+ time_image_embeds = self.image_proj(image_embeds)
429
+
430
+ return time_image_embeds + time_text_embeds
431
+
432
+
433
+ class AttentionPooling(nn.Module):
434
+ # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
435
+
436
+ def __init__(self, num_heads, embed_dim, dtype=None):
437
+ super().__init__()
438
+ self.dtype = dtype
439
+ self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
440
+ self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
441
+ self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
442
+ self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
443
+ self.num_heads = num_heads
444
+ self.dim_per_head = embed_dim // self.num_heads
445
+
446
+ def forward(self, x):
447
+ bs, length, width = x.size()
448
+
449
+ def shape(x):
450
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
451
+ x = x.view(bs, -1, self.num_heads, self.dim_per_head)
452
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
453
+ x = x.transpose(1, 2)
454
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
455
+ x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
456
+ # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
457
+ x = x.transpose(1, 2)
458
+ return x
459
+
460
+ class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
461
+ x = torch.cat([class_token, x], dim=1) # (bs, length+1, width)
462
+
463
+ # (bs*n_heads, class_token_length, dim_per_head)
464
+ q = shape(self.q_proj(class_token))
465
+ # (bs*n_heads, length+class_token_length, dim_per_head)
466
+ k = shape(self.k_proj(x))
467
+ v = shape(self.v_proj(x))
468
+
469
+ # (bs*n_heads, class_token_length, length+class_token_length):
470
+ scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
471
+ weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
472
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
473
+
474
+ # (bs*n_heads, dim_per_head, class_token_length)
475
+ a = torch.einsum("bts,bcs->bct", weight, v)
476
+
477
+ # (bs, length+1, width)
478
+ a = a.reshape(bs, -1, 1).transpose(1, 2)
479
+
480
+ return a[:, 0, :] # cls_token
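
Most of the classes in this file compose `get_timestep_embedding`: `Timesteps` turns integer timesteps into a fixed sinusoidal code, and `TimestepEmbedding` projects that code through a small MLP. A minimal usage sketch follows; the channel sizes and the `flip_sin_to_cos`/`downscale_freq_shift` values are illustrative assumptions, not the settings used by this app.

```python
# Illustrative sketch only; parameter values are assumptions.
import torch

from diffusers.models.embeddings import TimestepEmbedding, Timesteps

time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
time_embed = TimestepEmbedding(in_channels=256, time_embed_dim=1024)

t = torch.randint(0, 1000, (4,))   # one integer timestep per batch element
emb = time_embed(time_proj(t))     # (4,) -> (4, 256) -> (4, 1024)
print(emb.shape)                   # torch.Size([4, 1024])
```
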
diffusers/models/loaders.py ADDED
@@ -0,0 +1,1481 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import warnings
17
+ from collections import defaultdict
18
+ from pathlib import Path
19
+ from typing import Callable, Dict, List, Optional, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from huggingface_hub import hf_hub_download
24
+
25
+ from .attention_processor import (
26
+ AttnAddedKVProcessor,
27
+ AttnAddedKVProcessor2_0,
28
+ CustomDiffusionAttnProcessor,
29
+ CustomDiffusionXFormersAttnProcessor,
30
+ LoRAAttnAddedKVProcessor,
31
+ LoRAAttnProcessor,
32
+ LoRAAttnProcessor2_0,
33
+ LoRAXFormersAttnProcessor,
34
+ SlicedAttnAddedKVProcessor,
35
+ XFormersAttnProcessor,
36
+ )
37
+ from ..utils.constants import DIFFUSERS_CACHE, TEXT_ENCODER_ATTN_MODULE
38
+ from ..utils.hub_utils import HF_HUB_OFFLINE, _get_model_file
39
+ from ..utils.deprecation_utils import deprecate
40
+ from ..utils.import_utils import is_safetensors_available, is_transformers_available
41
+ from ..utils.logging import get_logger
42
+
43
+ if is_safetensors_available():
44
+ import safetensors
45
+
46
+ if is_transformers_available():
47
+ from transformers import PreTrainedModel, PreTrainedTokenizer
48
+
49
+
50
+ logger = get_logger(__name__)
51
+
52
+ TEXT_ENCODER_NAME = "text_encoder"
53
+ UNET_NAME = "unet"
54
+
55
+ LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
56
+ LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
57
+
58
+ TEXT_INVERSION_NAME = "learned_embeds.bin"
59
+ TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
60
+
61
+ CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
62
+ CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
63
+
64
+
65
+ class AttnProcsLayers(torch.nn.Module):
66
+ def __init__(self, state_dict: Dict[str, torch.Tensor]):
67
+ super().__init__()
68
+ self.layers = torch.nn.ModuleList(state_dict.values())
69
+ self.mapping = dict(enumerate(state_dict.keys()))
70
+ self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
71
+
72
+ # .processor for unet, .self_attn for text encoder
73
+ self.split_keys = [".processor", ".self_attn"]
74
+
75
+ # we add a hook to state_dict() and load_state_dict() so that the
76
+ # naming fits with `unet.attn_processors`
77
+ def map_to(module, state_dict, *args, **kwargs):
78
+ new_state_dict = {}
79
+ for key, value in state_dict.items():
80
+ num = int(key.split(".")[1]) # 0 is always "layers"
81
+ new_key = key.replace(f"layers.{num}", module.mapping[num])
82
+ new_state_dict[new_key] = value
83
+
84
+ return new_state_dict
85
+
86
+ def remap_key(key, state_dict):
87
+ for k in self.split_keys:
88
+ if k in key:
89
+ return key.split(k)[0] + k
90
+
91
+ raise ValueError(
92
+ f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}."
93
+ )
94
+
95
+ def map_from(module, state_dict, *args, **kwargs):
96
+ all_keys = list(state_dict.keys())
97
+ for key in all_keys:
98
+ replace_key = remap_key(key, state_dict)
99
+ new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
100
+ state_dict[new_key] = state_dict[key]
101
+ del state_dict[key]
102
+
103
+ self._register_state_dict_hook(map_to)
104
+ self._register_load_state_dict_pre_hook(map_from, with_module=True)
105
+
106
+
107
+ class UNet2DConditionLoadersMixin:
108
+ text_encoder_name = TEXT_ENCODER_NAME
109
+ unet_name = UNET_NAME
110
+
111
+ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
112
+ r"""
113
+ Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
114
+ defined in
115
+ [`cross_attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py)
116
+ and be a `torch.nn.Module` class.
117
+
118
+ Parameters:
119
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
120
+ Can be either:
121
+
122
+ - A string, the model id (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
123
+ the Hub.
124
+ - A path to a directory (for example `./my_model_directory`) containing the model weights saved
125
+ with [`ModelMixin.save_pretrained`].
126
+ - A [torch state
127
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
128
+
129
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
130
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
131
+ is not used.
132
+ force_download (`bool`, *optional*, defaults to `False`):
133
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
134
+ cached versions if they exist.
135
+ resume_download (`bool`, *optional*, defaults to `False`):
136
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
137
+ incompletely downloaded files are deleted.
138
+ proxies (`Dict[str, str]`, *optional*):
139
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
140
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
141
+ local_files_only (`bool`, *optional*, defaults to `False`):
142
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
143
+ won't be downloaded from the Hub.
144
+ use_auth_token (`str` or *bool*, *optional*):
145
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
146
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
147
+ revision (`str`, *optional*, defaults to `"main"`):
148
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
149
+ allowed by Git.
150
+ subfolder (`str`, *optional*, defaults to `""`):
151
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
152
+ mirror (`str`, *optional*):
153
+ Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
154
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
155
+ information.
156
+
157
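+ Example:
+
+ A minimal sketch following the usual 🤗 Diffusers usage; the model id and the weights location below
+ are illustrative placeholders:
+
+ ```py
+ from diffusers import StableDiffusionPipeline
+ import torch
+
+ model_id = "runwayml/stable-diffusion-v1-5"  # placeholder model id
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+
+ # a local directory or Hub repo containing `pytorch_lora_weights.bin` or `.safetensors`
+ pipe.unet.load_attn_procs("path/to/lora_weights")
+ ```
+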
+ """
158
+
159
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
160
+ force_download = kwargs.pop("force_download", False)
161
+ resume_download = kwargs.pop("resume_download", False)
162
+ proxies = kwargs.pop("proxies", None)
163
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
164
+ use_auth_token = kwargs.pop("use_auth_token", None)
165
+ revision = kwargs.pop("revision", None)
166
+ subfolder = kwargs.pop("subfolder", None)
167
+ weight_name = kwargs.pop("weight_name", None)
168
+ use_safetensors = kwargs.pop("use_safetensors", None)
169
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
170
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
171
+ network_alpha = kwargs.pop("network_alpha", None)
172
+
173
+ if use_safetensors and not is_safetensors_available():
174
+ raise ValueError(
175
+ "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
176
+ )
177
+
178
+ allow_pickle = False
179
+ if use_safetensors is None:
180
+ use_safetensors = is_safetensors_available()
181
+ allow_pickle = True
182
+
183
+ user_agent = {
184
+ "file_type": "attn_procs_weights",
185
+ "framework": "pytorch",
186
+ }
187
+
188
+ model_file = None
189
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
190
+ # Let's first try to load .safetensors weights
191
+ if (use_safetensors and weight_name is None) or (
192
+ weight_name is not None and weight_name.endswith(".safetensors")
193
+ ):
194
+ try:
195
+ model_file = _get_model_file(
196
+ pretrained_model_name_or_path_or_dict,
197
+ weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
198
+ cache_dir=cache_dir,
199
+ force_download=force_download,
200
+ resume_download=resume_download,
201
+ proxies=proxies,
202
+ local_files_only=local_files_only,
203
+ use_auth_token=use_auth_token,
204
+ revision=revision,
205
+ subfolder=subfolder,
206
+ user_agent=user_agent,
207
+ )
208
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
209
+ except IOError as e:
210
+ if not allow_pickle:
211
+ raise e
212
+ # try loading non-safetensors weights
213
+ pass
214
+ if model_file is None:
215
+ model_file = _get_model_file(
216
+ pretrained_model_name_or_path_or_dict,
217
+ weights_name=weight_name or LORA_WEIGHT_NAME,
218
+ cache_dir=cache_dir,
219
+ force_download=force_download,
220
+ resume_download=resume_download,
221
+ proxies=proxies,
222
+ local_files_only=local_files_only,
223
+ use_auth_token=use_auth_token,
224
+ revision=revision,
225
+ subfolder=subfolder,
226
+ user_agent=user_agent,
227
+ )
228
+ state_dict = torch.load(model_file, map_location="cpu")
229
+ else:
230
+ state_dict = pretrained_model_name_or_path_or_dict
231
+
232
+ # fill attn processors
233
+ attn_processors = {}
234
+
235
+ is_lora = all("lora" in k for k in state_dict.keys())
236
+ is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
237
+
238
+ if is_lora:
239
+ is_new_lora_format = all(
240
+ key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
241
+ )
242
+ if is_new_lora_format:
243
+ # Strip the `"unet"` prefix.
244
+ is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
245
+ if is_text_encoder_present:
246
+ warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
247
+ warnings.warn(warn_message)
248
+ unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
249
+ state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
250
+
251
+ lora_grouped_dict = defaultdict(dict)
252
+ for key, value in state_dict.items():
253
+ attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
254
+ lora_grouped_dict[attn_processor_key][sub_key] = value
255
+
256
+ for key, value_dict in lora_grouped_dict.items():
257
+ rank = value_dict["to_k_lora.down.weight"].shape[0]
258
+ hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
259
+
260
+ attn_processor = self
261
+ for sub_key in key.split("."):
262
+ attn_processor = getattr(attn_processor, sub_key)
263
+
264
+ if isinstance(
265
+ attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
266
+ ):
267
+ cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1]
268
+ attn_processor_class = LoRAAttnAddedKVProcessor
269
+ else:
270
+ cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
271
+ if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
272
+ attn_processor_class = LoRAXFormersAttnProcessor
273
+ else:
274
+ attn_processor_class = (
275
+ LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
276
+ )
277
+
278
+ attn_processors[key] = attn_processor_class(
279
+ hidden_size=hidden_size,
280
+ cross_attention_dim=cross_attention_dim,
281
+ rank=rank,
282
+ network_alpha=network_alpha,
283
+ )
284
+ attn_processors[key].load_state_dict(value_dict)
285
+ elif is_custom_diffusion:
286
+ custom_diffusion_grouped_dict = defaultdict(dict)
287
+ for key, value in state_dict.items():
288
+ if len(value) == 0:
289
+ custom_diffusion_grouped_dict[key] = {}
290
+ else:
291
+ if "to_out" in key:
292
+ attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
293
+ else:
294
+ attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
295
+ custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
296
+
297
+ for key, value_dict in custom_diffusion_grouped_dict.items():
298
+ if len(value_dict) == 0:
299
+ attn_processors[key] = CustomDiffusionAttnProcessor(
300
+ train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
301
+ )
302
+ else:
303
+ cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
304
+ hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
305
+ train_q_out = "to_q_custom_diffusion.weight" in value_dict
306
+ attn_processors[key] = CustomDiffusionAttnProcessor(
307
+ train_kv=True,
308
+ train_q_out=train_q_out,
309
+ hidden_size=hidden_size,
310
+ cross_attention_dim=cross_attention_dim,
311
+ )
312
+ attn_processors[key].load_state_dict(value_dict)
313
+ else:
314
+ raise ValueError(
315
+ f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
316
+ )
317
+
318
+ # set correct dtype & device
319
+ attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
320
+
321
+ # set layers
322
+ self.set_attn_processor(attn_processors)
323
+
324
+ def save_attn_procs(
325
+ self,
326
+ save_directory: Union[str, os.PathLike],
327
+ is_main_process: bool = True,
328
+ weight_name: str = None,
329
+ save_function: Callable = None,
330
+ safe_serialization: bool = False,
331
+ **kwargs,
332
+ ):
333
+ r"""
334
+ Save an attention processor to a directory so that it can be reloaded using the
335
+ [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method.
336
+
337
+ Arguments:
338
+ save_directory (`str` or `os.PathLike`):
339
+ Directory to save an attention processor to. Will be created if it doesn't exist.
340
+ is_main_process (`bool`, *optional*, defaults to `True`):
341
+ Whether the process calling this is the main process or not. Useful during distributed training when you
342
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
343
+ process to avoid race conditions.
344
+ save_function (`Callable`):
345
+ The function to use to save the state dictionary. Useful during distributed training when you need to
346
+ replace `torch.save` with another method. Can be configured with the environment variable
347
+ `DIFFUSERS_SAVE_MODE`.
348
+
349
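+ Example:
+
+ A minimal sketch; the output directory name is a placeholder and `unet` is assumed to be a
+ `UNet2DConditionModel` whose attention processors were trained or loaded beforehand:
+
+ ```py
+ # save the current attention processors (LoRA or Custom Diffusion) to disk
+ unet.save_attn_procs("my_lora_dir")
+
+ # they can later be restored with:
+ # unet.load_attn_procs("my_lora_dir")
+ ```
+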
+ """
350
+ weight_name = weight_name or deprecate(
351
+ "weights_name",
352
+ "0.20.0",
353
+ "`weights_name` is deprecated, please use `weight_name` instead.",
354
+ take_from=kwargs,
355
+ )
356
+ if os.path.isfile(save_directory):
357
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
358
+ return
359
+
360
+ if save_function is None:
361
+ if safe_serialization:
362
+
363
+ def save_function(weights, filename):
364
+ return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
365
+
366
+ else:
367
+ save_function = torch.save
368
+
369
+ os.makedirs(save_directory, exist_ok=True)
370
+
371
+ is_custom_diffusion = any(
372
+ isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
373
+ for (_, x) in self.attn_processors.items()
374
+ )
375
+ if is_custom_diffusion:
376
+ model_to_save = AttnProcsLayers(
377
+ {
378
+ y: x
379
+ for (y, x) in self.attn_processors.items()
380
+ if isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
381
+ }
382
+ )
383
+ state_dict = model_to_save.state_dict()
384
+ for name, attn in self.attn_processors.items():
385
+ if len(attn.state_dict()) == 0:
386
+ state_dict[name] = {}
387
+ else:
388
+ model_to_save = AttnProcsLayers(self.attn_processors)
389
+ state_dict = model_to_save.state_dict()
390
+
391
+ if weight_name is None:
392
+ if safe_serialization:
393
+ weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE
394
+ else:
395
+ weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME
396
+
397
+ # Save the model
398
+ save_function(state_dict, os.path.join(save_directory, weight_name))
399
+ logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
400
+
401
+
402
+ class TextualInversionLoaderMixin:
403
+ r"""
404
+ Load textual inversion tokens and embeddings to the tokenizer and text encoder.
405
+ """
406
+
407
+ def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"):
408
+ r"""
409
+ Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to
410
+ be replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
411
+ inversion token or if the textual inversion token is a single vector, the input prompt is returned.
412
+
413
+ Parameters:
414
+ prompt (`str` or list of `str`):
415
+ The prompt or prompts to guide the image generation.
416
+ tokenizer (`PreTrainedTokenizer`):
417
+ The tokenizer responsible for encoding the prompt into input tokens.
418
+
419
+ Returns:
420
+ `str` or list of `str`: The converted prompt
421
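+ Example:
+
+ A small illustration, assuming `<cat-toy>` was loaded as a three-vector textual inversion embedding;
+ `pipe` and the prompt are placeholders:
+
+ ```py
+ # expands the prompt to "A <cat-toy> <cat-toy>_1 <cat-toy>_2 backpack" before tokenization
+ prompt = pipe.maybe_convert_prompt("A <cat-toy> backpack", pipe.tokenizer)
+ ```
+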
+ """
422
+ if not isinstance(prompt, List):
423
+ prompts = [prompt]
424
+ else:
425
+ prompts = prompt
426
+
427
+ prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]
428
+
429
+ if not isinstance(prompt, List):
430
+ return prompts[0]
431
+
432
+ return prompts
433
+
434
+ def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
435
+ r"""
436
+ Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
437
+ to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
438
+ is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
439
+ inversion token or a textual inversion token that is a single vector, the input prompt is simply returned.
440
+
441
+ Parameters:
442
+ prompt (`str`):
443
+ The prompt to guide the image generation.
444
+ tokenizer (`PreTrainedTokenizer`):
445
+ The tokenizer responsible for encoding the prompt into input tokens.
446
+
447
+ Returns:
448
+ `str`: The converted prompt
449
+ """
450
+ tokens = tokenizer.tokenize(prompt)
451
+ unique_tokens = set(tokens)
452
+ for token in unique_tokens:
453
+ if token in tokenizer.added_tokens_encoder:
454
+ replacement = token
455
+ i = 1
456
+ while f"{token}_{i}" in tokenizer.added_tokens_encoder:
457
+ replacement += f" {token}_{i}"
458
+ i += 1
459
+
460
+ prompt = prompt.replace(token, replacement)
461
+
462
+ return prompt
463
+
464
+ def load_textual_inversion(
465
+ self,
466
+ pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
467
+ token: Optional[Union[str, List[str]]] = None,
468
+ **kwargs,
469
+ ):
470
+ r"""
471
+ Load textual inversion embeddings into the text encoder of [`StableDiffusionPipeline`] (both 🤗 Diffusers and
472
+ Automatic1111 formats are supported).
473
+
474
+ Parameters:
475
+ pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
476
+ Can be either one of the following or a list of them:
477
+
478
+ - A string, the *model id* (for example `sd-concepts-library/low-poly-hd-logos-icons`) of a
479
+ pretrained model hosted on the Hub.
480
+ - A path to a *directory* (for example `./my_text_inversion_directory/`) containing the textual
481
+ inversion weights.
482
+ - A path to a *file* (for example `./my_text_inversions.pt`) containing textual inversion weights.
483
+ - A [torch state
484
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
485
+
486
+ token (`str` or `List[str]`, *optional*):
487
+ Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
488
+ list, then `token` must also be a list of equal length.
489
+ weight_name (`str`, *optional*):
490
+ Name of a custom weight file. This should be used when:
491
+
492
+ - The saved textual inversion file is in 🤗 Diffusers format, but was saved under a specific weight
493
+ name such as `text_inv.bin`.
494
+ - The saved textual inversion file is in the Automatic1111 format.
495
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
496
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
497
+ is not used.
498
+ force_download (`bool`, *optional*, defaults to `False`):
499
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
500
+ cached versions if they exist.
501
+ resume_download (`bool`, *optional*, defaults to `False`):
502
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
503
+ incompletely downloaded files are deleted.
504
+ proxies (`Dict[str, str]`, *optional*):
505
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
506
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
507
+ local_files_only (`bool`, *optional*, defaults to `False`):
508
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
509
+ won't be downloaded from the Hub.
510
+ use_auth_token (`str` or *bool*, *optional*):
511
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
512
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
513
+ revision (`str`, *optional*, defaults to `"main"`):
514
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
515
+ allowed by Git.
516
+ subfolder (`str`, *optional*, defaults to `""`):
517
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
518
+ mirror (`str`, *optional*):
519
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
520
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
521
+ information.
522
+
523
+ Example:
524
+
525
+ To load a textual inversion embedding vector in 🤗 Diffusers format:
526
+
527
+ ```py
528
+ from diffusers import StableDiffusionPipeline
529
+ import torch
530
+
531
+ model_id = "runwayml/stable-diffusion-v1-5"
532
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
533
+
534
+ pipe.load_textual_inversion("sd-concepts-library/cat-toy")
535
+
536
+ prompt = "A <cat-toy> backpack"
537
+
538
+ image = pipe(prompt, num_inference_steps=50).images[0]
539
+ image.save("cat-backpack.png")
540
+ ```
541
+
542
+ To load a textual inversion embedding vector in Automatic1111 format, make sure to download the vector first
543
+ (for example from [civitAI](https://civitai.com/models/3036?modelVersionId=9857)) and then load the vector
544
+ locally:
545
+
546
+ ```py
547
+ from diffusers import StableDiffusionPipeline
548
+ import torch
549
+
550
+ model_id = "runwayml/stable-diffusion-v1-5"
551
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
552
+
553
+ pipe.load_textual_inversion("./charturnerv2.pt", token="charturnerv2")
554
+
555
+ prompt = "charturnerv2, multiple views of the same character in the same outfit, a character turnaround of a woman wearing a black jacket and red shirt, best quality, intricate details."
556
+
557
+ image = pipe(prompt, num_inference_steps=50).images[0]
558
+ image.save("character.png")
559
+ ```
560
+
561
+ """
562
+ if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
563
+ raise ValueError(
564
+ f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling"
565
+ f" `{self.load_textual_inversion.__name__}`"
566
+ )
567
+
568
+ if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel):
569
+ raise ValueError(
570
+ f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling"
571
+ f" `{self.load_textual_inversion.__name__}`"
572
+ )
573
+
574
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
575
+ force_download = kwargs.pop("force_download", False)
576
+ resume_download = kwargs.pop("resume_download", False)
577
+ proxies = kwargs.pop("proxies", None)
578
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
579
+ use_auth_token = kwargs.pop("use_auth_token", None)
580
+ revision = kwargs.pop("revision", None)
581
+ subfolder = kwargs.pop("subfolder", None)
582
+ weight_name = kwargs.pop("weight_name", None)
583
+ use_safetensors = kwargs.pop("use_safetensors", None)
584
+
585
+ if use_safetensors and not is_safetensors_available():
586
+ raise ValueError(
587
+ "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
588
+ )
589
+
590
+ allow_pickle = False
591
+ if use_safetensors is None:
592
+ use_safetensors = is_safetensors_available()
593
+ allow_pickle = True
594
+
595
+ user_agent = {
596
+ "file_type": "text_inversion",
597
+ "framework": "pytorch",
598
+ }
599
+
600
+ if not isinstance(pretrained_model_name_or_path, list):
601
+ pretrained_model_name_or_paths = [pretrained_model_name_or_path]
602
+ else:
603
+ pretrained_model_name_or_paths = pretrained_model_name_or_path
604
+
605
+ if isinstance(token, str):
606
+ tokens = [token]
607
+ elif token is None:
608
+ tokens = [None] * len(pretrained_model_name_or_paths)
609
+ else:
610
+ tokens = token
611
+
612
+ if len(pretrained_model_name_or_paths) != len(tokens):
613
+ raise ValueError(
614
+ f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)}"
615
+ f"Make sure both lists have the same length."
616
+ )
617
+
618
+ valid_tokens = [t for t in tokens if t is not None]
619
+ if len(set(valid_tokens)) < len(valid_tokens):
620
+ raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")
621
+
622
+ token_ids_and_embeddings = []
623
+
624
+ for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
625
+ if not isinstance(pretrained_model_name_or_path, dict):
626
+ # 1. Load textual inversion file
627
+ model_file = None
628
+ # Let's first try to load .safetensors weights
629
+ if (use_safetensors and weight_name is None) or (
630
+ weight_name is not None and weight_name.endswith(".safetensors")
631
+ ):
632
+ try:
633
+ model_file = _get_model_file(
634
+ pretrained_model_name_or_path,
635
+ weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
636
+ cache_dir=cache_dir,
637
+ force_download=force_download,
638
+ resume_download=resume_download,
639
+ proxies=proxies,
640
+ local_files_only=local_files_only,
641
+ use_auth_token=use_auth_token,
642
+ revision=revision,
643
+ subfolder=subfolder,
644
+ user_agent=user_agent,
645
+ )
646
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
647
+ except Exception as e:
648
+ if not allow_pickle:
649
+ raise e
650
+
651
+ model_file = None
652
+
653
+ if model_file is None:
654
+ model_file = _get_model_file(
655
+ pretrained_model_name_or_path,
656
+ weights_name=weight_name or TEXT_INVERSION_NAME,
657
+ cache_dir=cache_dir,
658
+ force_download=force_download,
659
+ resume_download=resume_download,
660
+ proxies=proxies,
661
+ local_files_only=local_files_only,
662
+ use_auth_token=use_auth_token,
663
+ revision=revision,
664
+ subfolder=subfolder,
665
+ user_agent=user_agent,
666
+ )
667
+ state_dict = torch.load(model_file, map_location="cpu")
668
+ else:
669
+ state_dict = pretrained_model_name_or_path
670
+
671
+ # 2. Load token and embedding correctly from file
672
+ loaded_token = None
673
+ if isinstance(state_dict, torch.Tensor):
674
+ if token is None:
675
+ raise ValueError(
676
+ "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
677
+ )
678
+ embedding = state_dict
679
+ elif len(state_dict) == 1:
680
+ # diffusers
681
+ loaded_token, embedding = next(iter(state_dict.items()))
682
+ elif "string_to_param" in state_dict:
683
+ # A1111
684
+ loaded_token = state_dict["name"]
685
+ embedding = state_dict["string_to_param"]["*"]
686
+
687
+ if token is not None and loaded_token != token:
688
+ logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
689
+ else:
690
+ token = loaded_token
691
+
692
+ embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)
693
+
694
+ # 3. Make sure we don't mess up the tokenizer or text encoder
695
+ vocab = self.tokenizer.get_vocab()
696
+ if token in vocab:
697
+ raise ValueError(
698
+ f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
699
+ )
700
+ elif f"{token}_1" in vocab:
701
+ multi_vector_tokens = [token]
702
+ i = 1
703
+ while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
704
+ multi_vector_tokens.append(f"{token}_{i}")
705
+ i += 1
706
+
707
+ raise ValueError(
708
+ f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
709
+ )
710
+
711
+ is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
712
+
713
+ if is_multi_vector:
714
+ tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
715
+ embeddings = [e for e in embedding] # noqa: C416
716
+ else:
717
+ tokens = [token]
718
+ embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]
719
+
720
+ # add tokens and get ids
721
+ self.tokenizer.add_tokens(tokens)
722
+ token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
723
+ token_ids_and_embeddings += zip(token_ids, embeddings)
724
+
725
+ logger.info(f"Loaded textual inversion embedding for {token}.")
726
+
727
+ # resize token embeddings and set all new embeddings
728
+ self.text_encoder.resize_token_embeddings(len(self.tokenizer))
729
+ for token_id, embedding in token_ids_and_embeddings:
730
+ self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
731
+
732
+
733
+ class LoraLoaderMixin:
734
+ r"""
735
+ Load LoRA layers into [`UNet2DConditionModel`] and
736
+ [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
737
+ """
738
+ text_encoder_name = TEXT_ENCODER_NAME
739
+ unet_name = UNET_NAME
740
+
741
+ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
742
+ r"""
743
+ Load pretrained LoRA attention processor layers into [`UNet2DConditionModel`] and
744
+ [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
745
+
746
+ Parameters:
747
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
748
+ Can be either:
749
+
750
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
751
+ the Hub.
752
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
753
+ with [`ModelMixin.save_pretrained`].
754
+ - A [torch state
755
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
756
+
757
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
758
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
759
+ is not used.
760
+ force_download (`bool`, *optional*, defaults to `False`):
761
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
762
+ cached versions if they exist.
763
+ resume_download (`bool`, *optional*, defaults to `False`):
764
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
765
+ incompletely downloaded files are deleted.
766
+ proxies (`Dict[str, str]`, *optional*):
767
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
768
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
769
+ local_files_only (`bool`, *optional*, defaults to `False`):
770
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
771
+ won't be downloaded from the Hub.
772
+ use_auth_token (`str` or *bool*, *optional*):
773
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
774
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
775
+ revision (`str`, *optional*, defaults to `"main"`):
776
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
777
+ allowed by Git.
778
+ subfolder (`str`, *optional*, defaults to `""`):
779
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
780
+ mirror (`str`, *optional*):
781
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
782
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
783
+ information.
784
+
785
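+ Example:
+
+ A minimal sketch of loading LoRA weights that cover the UNet and, optionally, the text encoder; the
+ model id and the LoRA location below are illustrative placeholders:
+
+ ```py
+ from diffusers import StableDiffusionPipeline
+ import torch
+
+ model_id = "runwayml/stable-diffusion-v1-5"  # placeholder model id
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+
+ # a local directory or Hub repo saved with `save_lora_weights`
+ pipe.load_lora_weights("path/to/lora")
+ image = pipe("a prompt using the LoRA concept", num_inference_steps=25).images[0]
+ ```
+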
+ """
786
+ # Load the main state dict first, which has the LoRA layers for the
787
+ # UNet, the text encoder, or both.
788
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
789
+ force_download = kwargs.pop("force_download", False)
790
+ resume_download = kwargs.pop("resume_download", False)
791
+ proxies = kwargs.pop("proxies", None)
792
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
793
+ use_auth_token = kwargs.pop("use_auth_token", None)
794
+ revision = kwargs.pop("revision", None)
795
+ subfolder = kwargs.pop("subfolder", None)
796
+ weight_name = kwargs.pop("weight_name", None)
797
+ use_safetensors = kwargs.pop("use_safetensors", None)
798
+
799
+ # set lora scale to a reasonable default
800
+ self._lora_scale = 1.0
801
+
802
+ if use_safetensors and not is_safetensors_available():
803
+ raise ValueError(
804
+ "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
805
+ )
806
+
807
+ allow_pickle = False
808
+ if use_safetensors is None:
809
+ use_safetensors = is_safetensors_available()
810
+ allow_pickle = True
811
+
812
+ user_agent = {
813
+ "file_type": "attn_procs_weights",
814
+ "framework": "pytorch",
815
+ }
816
+
817
+ model_file = None
818
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
819
+ # Let's first try to load .safetensors weights
820
+ if (use_safetensors and weight_name is None) or (
821
+ weight_name is not None and weight_name.endswith(".safetensors")
822
+ ):
823
+ try:
824
+ model_file = _get_model_file(
825
+ pretrained_model_name_or_path_or_dict,
826
+ weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
827
+ cache_dir=cache_dir,
828
+ force_download=force_download,
829
+ resume_download=resume_download,
830
+ proxies=proxies,
831
+ local_files_only=local_files_only,
832
+ use_auth_token=use_auth_token,
833
+ revision=revision,
834
+ subfolder=subfolder,
835
+ user_agent=user_agent,
836
+ )
837
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
838
+ except IOError as e:
839
+ if not allow_pickle:
840
+ raise e
841
+ # try loading non-safetensors weights
842
+ pass
843
+ if model_file is None:
844
+ model_file = _get_model_file(
845
+ pretrained_model_name_or_path_or_dict,
846
+ weights_name=weight_name or LORA_WEIGHT_NAME,
847
+ cache_dir=cache_dir,
848
+ force_download=force_download,
849
+ resume_download=resume_download,
850
+ proxies=proxies,
851
+ local_files_only=local_files_only,
852
+ use_auth_token=use_auth_token,
853
+ revision=revision,
854
+ subfolder=subfolder,
855
+ user_agent=user_agent,
856
+ )
857
+ state_dict = torch.load(model_file, map_location="cpu")
858
+ else:
859
+ state_dict = pretrained_model_name_or_path_or_dict
860
+
861
+ # Convert kohya-ss Style LoRA attn procs to diffusers attn procs
862
+ network_alpha = None
863
+ if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
864
+ state_dict, network_alpha = self._convert_kohya_lora_to_diffusers(state_dict)
865
+
866
+ # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
867
+ # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
868
+ # their prefixes.
869
+ keys = list(state_dict.keys())
870
+ if all(key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in keys):
871
+ # Load the layers corresponding to UNet.
872
+ unet_keys = [k for k in keys if k.startswith(self.unet_name)]
873
+ logger.info(f"Loading {self.unet_name}.")
874
+ unet_lora_state_dict = {
875
+ k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
876
+ }
877
+ self.unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha)
878
+
879
+ # Load the layers corresponding to text encoder and make necessary adjustments.
880
+ text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
881
+ text_encoder_lora_state_dict = {
882
+ k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
883
+ }
884
+ if len(text_encoder_lora_state_dict) > 0:
885
+ logger.info(f"Loading {self.text_encoder_name}.")
886
+ attn_procs_text_encoder = self._load_text_encoder_attn_procs(
887
+ text_encoder_lora_state_dict, network_alpha=network_alpha
888
+ )
889
+ self._modify_text_encoder(attn_procs_text_encoder)
890
+
891
+ # save the LoRA attn procs of the text encoder so that they can be easily retrieved
892
+ self._text_encoder_lora_attn_procs = attn_procs_text_encoder
893
+
894
+ # Otherwise, we're dealing with the old format. This means the `state_dict` should only
895
+ # contain the module names of the `unet` as its keys WITHOUT any prefix.
896
+ elif not all(
897
+ key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
898
+ ):
899
+ self.unet.load_attn_procs(state_dict)
900
+ warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
901
+ warnings.warn(warn_message)
902
+
903
+ @property
904
+ def lora_scale(self) -> float:
905
+ # property function that returns the lora scale which can be set at run time by the pipeline.
906
+ # if _lora_scale has not been set, return 1
907
+ return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
908
+
909
+ @property
910
+ def text_encoder_lora_attn_procs(self):
911
+ if hasattr(self, "_text_encoder_lora_attn_procs"):
912
+ return self._text_encoder_lora_attn_procs
913
+ return
914
+
915
+ def _remove_text_encoder_monkey_patch(self):
916
+ # Loop over the CLIPAttention module of text_encoder
917
+ for name, attn_module in self.text_encoder.named_modules():
918
+ if name.endswith(TEXT_ENCODER_ATTN_MODULE):
919
+ # Loop over the LoRA layers
920
+ for _, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
921
+ # Retrieve the q/k/v/out projection of CLIPAttention
922
+ module = attn_module.get_submodule(text_encoder_attr)
923
+ if hasattr(module, "old_forward"):
924
+ # restore original `forward` to remove monkey-patch
925
+ module.forward = module.old_forward
926
+ delattr(module, "old_forward")
927
+
928
+ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
929
+ r"""
930
+ Monkey-patches the forward passes of attention modules of the text encoder.
931
+
932
+ Parameters:
933
+ attn_processors: Dict[str, `LoRAAttnProcessor`]:
934
+ A dictionary mapping the module names and their corresponding [`~LoRAAttnProcessor`].
935
+ """
936
+
937
+ # First, remove any monkey-patch that might have been applied before
938
+ self._remove_text_encoder_monkey_patch()
939
+
940
+ # Loop over the CLIPAttention module of text_encoder
941
+ for name, attn_module in self.text_encoder.named_modules():
942
+ if name.endswith(TEXT_ENCODER_ATTN_MODULE):
943
+ # Loop over the LoRA layers
944
+ for attn_proc_attr, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
945
+ # Retrieve the q/k/v/out projection of CLIPAttention and its corresponding LoRA layer.
946
+ module = attn_module.get_submodule(text_encoder_attr)
947
+ lora_layer = attn_processors[name].get_submodule(attn_proc_attr)
948
+
949
+ # save old_forward to module that can be used to remove monkey-patch
950
+ old_forward = module.old_forward = module.forward
951
+
952
+ # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
953
+ # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
954
+ def make_new_forward(old_forward, lora_layer):
955
+ def new_forward(x):
956
+ result = old_forward(x) + self.lora_scale * lora_layer(x)
957
+ return result
958
+
959
+ return new_forward
960
+
961
+ # Monkey-patch.
962
+ module.forward = make_new_forward(old_forward, lora_layer)
963
+
964
+ @property
965
+ def _lora_attn_processor_attr_to_text_encoder_attr(self):
966
+ return {
967
+ "to_q_lora": "q_proj",
968
+ "to_k_lora": "k_proj",
969
+ "to_v_lora": "v_proj",
970
+ "to_out_lora": "out_proj",
971
+ }
972
+
973
+ def _load_text_encoder_attn_procs(
974
+ self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
975
+ ):
976
+ r"""
977
+ Load pretrained attention processor layers for
978
+ [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
979
+
980
+ <Tip warning={true}>
981
+
982
+ This function is experimental and might change in the future.
983
+
984
+ </Tip>
985
+
986
+ Parameters:
987
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
988
+ Can be either:
989
+
990
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
991
+ Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
992
+ - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
993
+ `./my_model_directory/`.
994
+ - A [torch state
995
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
996
+
997
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
998
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
999
+ standard cache should not be used.
1000
+ force_download (`bool`, *optional*, defaults to `False`):
1001
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
1002
+ cached versions if they exist.
1003
+ resume_download (`bool`, *optional*, defaults to `False`):
1004
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
1005
+ file exists.
1006
+ proxies (`Dict[str, str]`, *optional*):
1007
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
1008
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
1009
+ local_files_only (`bool`, *optional*, defaults to `False`):
1010
+ Whether or not to only look at local files (i.e., do not try to download the model).
1011
+ use_auth_token (`str` or *bool*, *optional*):
1012
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
1013
+ when running `diffusers-cli login` (stored in `~/.huggingface`).
1014
+ revision (`str`, *optional*, defaults to `"main"`):
1015
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
1016
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
1017
+ identifier allowed by git.
1018
+ subfolder (`str`, *optional*, defaults to `""`):
1019
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
1020
+ huggingface.co or downloaded locally), you can specify the folder name here.
1021
+ mirror (`str`, *optional*):
1022
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
1023
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
1024
+ Please refer to the mirror site for more information.
1025
+
1026
+ Returns:
1027
+ `Dict[name, LoRAAttnProcessor]`: Mapping between the module names and their corresponding
1028
+ [`LoRAAttnProcessor`].
1029
+
1030
+ <Tip>
1031
+
1032
+ It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
1033
+ models](https://huggingface.co/docs/hub/models-gated#gated-models).
1034
+
1035
+ </Tip>
1036
+ """
1037
+
1038
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
1039
+ force_download = kwargs.pop("force_download", False)
1040
+ resume_download = kwargs.pop("resume_download", False)
1041
+ proxies = kwargs.pop("proxies", None)
1042
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
1043
+ use_auth_token = kwargs.pop("use_auth_token", None)
1044
+ revision = kwargs.pop("revision", None)
1045
+ subfolder = kwargs.pop("subfolder", None)
1046
+ weight_name = kwargs.pop("weight_name", None)
1047
+ use_safetensors = kwargs.pop("use_safetensors", None)
1048
+ network_alpha = kwargs.pop("network_alpha", None)
1049
+
1050
+ if use_safetensors and not is_safetensors_available():
1051
+ raise ValueError(
1052
+ "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
1053
+ )
1054
+
1055
+ allow_pickle = False
1056
+ if use_safetensors is None:
1057
+ use_safetensors = is_safetensors_available()
1058
+ allow_pickle = True
1059
+
1060
+ user_agent = {
1061
+ "file_type": "attn_procs_weights",
1062
+ "framework": "pytorch",
1063
+ }
1064
+
1065
+ model_file = None
1066
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
1067
+ # Let's first try to load .safetensors weights
1068
+ if (use_safetensors and weight_name is None) or (
1069
+ weight_name is not None and weight_name.endswith(".safetensors")
1070
+ ):
1071
+ try:
1072
+ model_file = _get_model_file(
1073
+ pretrained_model_name_or_path_or_dict,
1074
+ weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
1075
+ cache_dir=cache_dir,
1076
+ force_download=force_download,
1077
+ resume_download=resume_download,
1078
+ proxies=proxies,
1079
+ local_files_only=local_files_only,
1080
+ use_auth_token=use_auth_token,
1081
+ revision=revision,
1082
+ subfolder=subfolder,
1083
+ user_agent=user_agent,
1084
+ )
1085
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
1086
+ except IOError as e:
1087
+ if not allow_pickle:
1088
+ raise e
1089
+ # try loading non-safetensors weights
1090
+ pass
1091
+ if model_file is None:
1092
+ model_file = _get_model_file(
1093
+ pretrained_model_name_or_path_or_dict,
1094
+ weights_name=weight_name or LORA_WEIGHT_NAME,
1095
+ cache_dir=cache_dir,
1096
+ force_download=force_download,
1097
+ resume_download=resume_download,
1098
+ proxies=proxies,
1099
+ local_files_only=local_files_only,
1100
+ use_auth_token=use_auth_token,
1101
+ revision=revision,
1102
+ subfolder=subfolder,
1103
+ user_agent=user_agent,
1104
+ )
1105
+ state_dict = torch.load(model_file, map_location="cpu")
1106
+ else:
1107
+ state_dict = pretrained_model_name_or_path_or_dict
1108
+
1109
+ # fill attn processors
1110
+ attn_processors = {}
1111
+
1112
+ is_lora = all("lora" in k for k in state_dict.keys())
1113
+
1114
+ if is_lora:
1115
+ lora_grouped_dict = defaultdict(dict)
1116
+ for key, value in state_dict.items():
1117
+ attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
1118
+ lora_grouped_dict[attn_processor_key][sub_key] = value
1119
+
1120
+ for key, value_dict in lora_grouped_dict.items():
1121
+ rank = value_dict["to_k_lora.down.weight"].shape[0]
1122
+ cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
1123
+ hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
1124
+
1125
+ attn_processor_class = (
1126
+ LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
1127
+ )
1128
+ attn_processors[key] = attn_processor_class(
1129
+ hidden_size=hidden_size,
1130
+ cross_attention_dim=cross_attention_dim,
1131
+ rank=rank,
1132
+ network_alpha=network_alpha,
1133
+ )
1134
+ attn_processors[key].load_state_dict(value_dict)
1135
+
1136
+ else:
1137
+ raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")
1138
+
1139
+ # set correct dtype & device
1140
+ attn_processors = {
1141
+ k: v.to(device=self.device, dtype=self.text_encoder.dtype) for k, v in attn_processors.items()
1142
+ }
1143
+ return attn_processors
1144
+
1145
+ @classmethod
1146
+ def save_lora_weights(
1147
+ self,
1148
+ save_directory: Union[str, os.PathLike],
1149
+ unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1150
+ text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
1151
+ is_main_process: bool = True,
1152
+ weight_name: str = None,
1153
+ save_function: Callable = None,
1154
+ safe_serialization: bool = False,
1155
+ ):
1156
+ r"""
1157
+ Save the LoRA parameters corresponding to the UNet and text encoder.
1158
+
1159
+ Arguments:
1160
+ save_directory (`str` or `os.PathLike`):
1161
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
1162
+ unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
1163
+ State dict of the LoRA layers corresponding to the UNet.
1164
+ text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
1165
+ State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
1166
+ encoder LoRA state dict because it comes from 🤗 Transformers.
1167
+ is_main_process (`bool`, *optional*, defaults to `True`):
1168
+ Whether the process calling this is the main process or not. Useful during distributed training when you
1169
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
1170
+ process to avoid race conditions.
1171
+ save_function (`Callable`):
1172
+ The function to use to save the state dictionary. Useful during distributed training when you need to
1173
+ replace `torch.save` with another method. Can be configured with the environment variable
1174
+ `DIFFUSERS_SAVE_MODE`.
1175
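+ Example:
+
+ A minimal sketch; `unet_lora_layers` and `text_encoder_lora_layers` are assumed to be the LoRA layer
+ dictionaries (or `AttnProcsLayers` modules) produced by a LoRA training loop, and the directory name
+ is a placeholder:
+
+ ```py
+ from diffusers import StableDiffusionPipeline
+
+ StableDiffusionPipeline.save_lora_weights(
+     save_directory="my_lora",
+     unet_lora_layers=unet_lora_layers,
+     text_encoder_lora_layers=text_encoder_lora_layers,
+ )
+ ```
+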
+ """
1176
+ if os.path.isfile(save_directory):
1177
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
1178
+ return
1179
+
1180
+ if save_function is None:
1181
+ if safe_serialization:
1182
+
1183
+ def save_function(weights, filename):
1184
+ return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
1185
+
1186
+ else:
1187
+ save_function = torch.save
1188
+
1189
+ os.makedirs(save_directory, exist_ok=True)
1190
+
1191
+ # Create a flat dictionary.
1192
+ state_dict = {}
1193
+ if unet_lora_layers is not None:
1194
+ weights = (
1195
+ unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
1196
+ )
1197
+
1198
+ unet_lora_state_dict = {f"{self.unet_name}.{module_name}": param for module_name, param in weights.items()}
1199
+ state_dict.update(unet_lora_state_dict)
1200
+
1201
+ if text_encoder_lora_layers is not None:
1202
+ weights = (
1203
+ text_encoder_lora_layers.state_dict()
1204
+ if isinstance(text_encoder_lora_layers, torch.nn.Module)
1205
+ else text_encoder_lora_layers
1206
+ )
1207
+
1208
+ text_encoder_lora_state_dict = {
1209
+ f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
1210
+ }
1211
+ state_dict.update(text_encoder_lora_state_dict)
1212
+
1213
+ # Save the model
1214
+ if weight_name is None:
1215
+ if safe_serialization:
1216
+ weight_name = LORA_WEIGHT_NAME_SAFE
1217
+ else:
1218
+ weight_name = LORA_WEIGHT_NAME
1219
+
1220
+ save_function(state_dict, os.path.join(save_directory, weight_name))
1221
+ logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
1222
+
1223
+ def _convert_kohya_lora_to_diffusers(self, state_dict):
1224
+ unet_state_dict = {}
1225
+ te_state_dict = {}
1226
+ network_alpha = None
1227
+
1228
+ for key, value in state_dict.items():
1229
+ if "lora_down" in key:
1230
+ lora_name = key.split(".")[0]
1231
+ lora_name_up = lora_name + ".lora_up.weight"
1232
+ lora_name_alpha = lora_name + ".alpha"
1233
+ if lora_name_alpha in state_dict:
1234
+ alpha = state_dict[lora_name_alpha].item()
1235
+ if network_alpha is None:
1236
+ network_alpha = alpha
1237
+ elif network_alpha != alpha:
1238
+ raise ValueError("Network alpha is not consistent")
1239
+
1240
+ if lora_name.startswith("lora_unet_"):
1241
+ diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
1242
+ diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
1243
+ diffusers_name = diffusers_name.replace("mid.block", "mid_block")
1244
+ diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
1245
+ diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
1246
+ diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
1247
+ diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
1248
+ diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
1249
+ diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
1250
+ if "transformer_blocks" in diffusers_name:
1251
+ if "attn1" in diffusers_name or "attn2" in diffusers_name:
1252
+ diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
1253
+ diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
1254
+ unet_state_dict[diffusers_name] = value
1255
+ unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
1256
+ elif lora_name.startswith("lora_te_"):
1257
+ diffusers_name = key.replace("lora_te_", "").replace("_", ".")
1258
+ diffusers_name = diffusers_name.replace("text.model", "text_model")
1259
+ diffusers_name = diffusers_name.replace("self.attn", "self_attn")
1260
+ diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
1261
+ diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
1262
+ diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
1263
+ diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
1264
+ if "self_attn" in diffusers_name:
1265
+ te_state_dict[diffusers_name] = value
1266
+ te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
1267
+
1268
+ unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
1269
+ te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
1270
+ new_state_dict = {**unet_state_dict, **te_state_dict}
1271
+ return new_state_dict, network_alpha
1272
+
1273
+
1274
+ class FromCkptMixin:
1275
+ """
1276
+ Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
1277
+ """
1278
+
1279
+ @classmethod
1280
+ def from_ckpt(cls, pretrained_model_link_or_path, **kwargs):
1281
+ r"""
1282
+ Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` format. The pipeline
1283
+ is set in evaluation mode (`model.eval()`) by default.
1284
+
1285
+ Parameters:
1286
+ pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
1287
+ Can be either:
1288
+ - A link to the `.ckpt` file (for example
1289
+ `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
1290
+ - A path to a *file* containing all pipeline weights.
1291
+ torch_dtype (`str` or `torch.dtype`, *optional*):
1292
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
1293
+ dtype is automatically derived from the model's weights.
1294
+ force_download (`bool`, *optional*, defaults to `False`):
1295
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
1296
+ cached versions if they exist.
1297
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
1298
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
1299
+ is not used.
1300
+ resume_download (`bool`, *optional*, defaults to `False`):
1301
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
1302
+ incompletely downloaded files are deleted.
1303
+ proxies (`Dict[str, str]`, *optional*):
1304
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
1305
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
1306
+ local_files_only (`bool`, *optional*, defaults to `False`):
1307
+ Whether to only load local model weights and configuration files or not. If set to True, the model
1308
+ won't be downloaded from the Hub.
1309
+ use_auth_token (`str` or *bool*, *optional*):
1310
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
1311
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
1312
+ revision (`str`, *optional*, defaults to `"main"`):
1313
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
1314
+ allowed by Git.
1315
+ use_safetensors (`bool`, *optional*, defaults to `None`):
1316
+ If set to `None`, the safetensors weights are downloaded if they're available **and** if the
1317
+ safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
1318
+ weights. If set to `False`, safetensors weights are not loaded.
1319
+ extract_ema (`bool`, *optional*, defaults to `False`):
1320
+ Whether to extract the EMA weights or not. Pass `True` to extract the EMA weights which usually yield
1321
+ higher quality images for inference. Non-EMA weights are usually better to continue finetuning.
1322
+ upcast_attention (`bool`, *optional*, defaults to `None`):
1323
+ Whether the attention computation should always be upcasted.
1324
+ image_size (`int`, *optional*, defaults to 512):
1325
+ The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
1326
+ Diffusion v2 base model. Use 768 for Stable Diffusion v2.
1327
+ prediction_type (`str`, *optional*):
1328
+ The prediction type the model was trained on. Use `'epsilon'` for all Stable Diffusion v1 models and
1329
+ the Stable Diffusion v2 base model. Use `'v_prediction'` for Stable Diffusion v2.
1330
+ num_in_channels (`int`, *optional*, defaults to `None`):
1331
+ The number of input channels. If `None`, it will be automatically inferred.
1332
+ scheduler_type (`str`, *optional*, defaults to `"pndm"`):
1333
+ Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
1334
+ "ddim"]`.
1335
+ load_safety_checker (`bool`, *optional*, defaults to `True`):
1336
+ Whether to load the safety checker or not.
1337
+ text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):
1338
+ An instance of
1339
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) to use,
1340
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
1341
+ variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if
1342
+ needed.
1343
+ tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`):
1344
+ An instance of
1345
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
1346
+ to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by
1347
+ itself, if needed.
1348
+ kwargs (remaining dictionary of keyword arguments, *optional*):
1349
+ Can be used to overwrite load and saveable variables (for example the pipeline components of the
1350
+ specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
1351
+ method. See example below for more information.
1352
+
1353
+ Examples:
1354
+
1355
+ ```py
1356
+ >>> from diffusers import StableDiffusionPipeline
+ >>> import torch
1357
+
1358
+ >>> # Download pipeline from huggingface.co and cache.
1359
+ >>> pipeline = StableDiffusionPipeline.from_ckpt(
1360
+ ... "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
1361
+ ... )
1362
+
1363
+ >>> # Load pipeline from a local checkpoint file
1364
+ >>> # (e.g. previously downloaded to ./v1-5-pruned-emaonly.ckpt)
1365
+ >>> pipeline = StableDiffusionPipeline.from_ckpt("./v1-5-pruned-emaonly.ckpt")
1366
+
1367
+ >>> # Enable float16 and move to GPU
1368
+ >>> pipeline = StableDiffusionPipeline.from_ckpt(
1369
+ ... "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
1370
+ ... torch_dtype=torch.float16,
1371
+ ... )
1372
+ >>> pipeline.to("cuda")
1373
+ ```
1374
+ """
1375
+ # import here to avoid circular dependency
1376
+ from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
1377
+
1378
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
1379
+ resume_download = kwargs.pop("resume_download", False)
1380
+ force_download = kwargs.pop("force_download", False)
1381
+ proxies = kwargs.pop("proxies", None)
1382
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
1383
+ use_auth_token = kwargs.pop("use_auth_token", None)
1384
+ revision = kwargs.pop("revision", None)
1385
+ extract_ema = kwargs.pop("extract_ema", False)
1386
+ image_size = kwargs.pop("image_size", 512)
1387
+ scheduler_type = kwargs.pop("scheduler_type", "pndm")
1388
+ num_in_channels = kwargs.pop("num_in_channels", None)
1389
+ upcast_attention = kwargs.pop("upcast_attention", None)
1390
+ load_safety_checker = kwargs.pop("load_safety_checker", True)
1391
+ prediction_type = kwargs.pop("prediction_type", None)
1392
+ text_encoder = kwargs.pop("text_encoder", None)
1393
+ tokenizer = kwargs.pop("tokenizer", None)
1394
+
1395
+ torch_dtype = kwargs.pop("torch_dtype", None)
1396
+
1397
+ use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
1398
+
1399
+ pipeline_name = cls.__name__
1400
+ file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
1401
+ from_safetensors = file_extension == "safetensors"
1402
+
1403
+ if from_safetensors and use_safetensors is False:
1404
+ raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
1405
+
1406
+ # TODO: For now we only support stable diffusion
1407
+ stable_unclip = None
1408
+ model_type = None
1409
+ controlnet = False
1410
+
1411
+ if pipeline_name == "StableDiffusionControlNetPipeline":
1412
+ # Model type will be inferred from the checkpoint.
1413
+ controlnet = True
1414
+ elif "StableDiffusion" in pipeline_name:
1415
+ # Model type will be inferred from the checkpoint.
1416
+ pass
1417
+ elif pipeline_name == "StableUnCLIPPipeline":
1418
+ model_type = "FrozenOpenCLIPEmbedder"
1419
+ stable_unclip = "txt2img"
1420
+ elif pipeline_name == "StableUnCLIPImg2ImgPipeline":
1421
+ model_type = "FrozenOpenCLIPEmbedder"
1422
+ stable_unclip = "img2img"
1423
+ elif pipeline_name == "PaintByExamplePipeline":
1424
+ model_type = "PaintByExample"
1425
+ elif pipeline_name == "LDMTextToImagePipeline":
1426
+ model_type = "LDMTextToImage"
1427
+ else:
1428
+ raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
1429
+
1430
+ # remove huggingface url
1431
+ for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
1432
+ if pretrained_model_link_or_path.startswith(prefix):
1433
+ pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
1434
+
1435
+ # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
1436
+ ckpt_path = Path(pretrained_model_link_or_path)
1437
+ if not ckpt_path.is_file():
1438
+ # get repo_id and (potentially nested) file path of ckpt in repo
1439
+ repo_id = "/".join(ckpt_path.parts[:2])
1440
+ file_path = "/".join(ckpt_path.parts[2:])
1441
+
1442
+ if file_path.startswith("blob/"):
1443
+ file_path = file_path[len("blob/") :]
1444
+
1445
+ if file_path.startswith("main/"):
1446
+ file_path = file_path[len("main/") :]
1447
+
1448
+ pretrained_model_link_or_path = hf_hub_download(
1449
+ repo_id,
1450
+ filename=file_path,
1451
+ cache_dir=cache_dir,
1452
+ resume_download=resume_download,
1453
+ proxies=proxies,
1454
+ local_files_only=local_files_only,
1455
+ use_auth_token=use_auth_token,
1456
+ revision=revision,
1457
+ force_download=force_download,
1458
+ )
1459
+
1460
+ pipe = download_from_original_stable_diffusion_ckpt(
1461
+ pretrained_model_link_or_path,
1462
+ pipeline_class=cls,
1463
+ model_type=model_type,
1464
+ stable_unclip=stable_unclip,
1465
+ controlnet=controlnet,
1466
+ from_safetensors=from_safetensors,
1467
+ extract_ema=extract_ema,
1468
+ image_size=image_size,
1469
+ scheduler_type=scheduler_type,
1470
+ num_in_channels=num_in_channels,
1471
+ upcast_attention=upcast_attention,
1472
+ load_safety_checker=load_safety_checker,
1473
+ prediction_type=prediction_type,
1474
+ text_encoder=text_encoder,
1475
+ tokenizer=tokenizer,
1476
+ )
1477
+
1478
+ if torch_dtype is not None:
1479
+ pipe.to(torch_dtype=torch_dtype)
1480
+
1481
+ return pipe
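
As a quick illustration of the link handling in `from_ckpt` above, the sketch below reproduces how a Hub checkpoint URL is reduced to a `repo_id` and an in-repo file path before `hf_hub_download` is called. This is a minimal sketch only: the example URL and the helper name `split_hub_link` are made up for illustration and are not part of the committed code.

```py
from pathlib import Path

def split_hub_link(link: str):
    """Mimic the prefix stripping done in from_ckpt (sketch only)."""
    # Drop any Hugging Face Hub prefix.
    for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
        if link.startswith(prefix):
            link = link[len(prefix):]
    parts = Path(link).parts
    repo_id = "/".join(parts[:2])    # e.g. "user/repo"
    file_path = "/".join(parts[2:])  # e.g. "blob/main/model.safetensors"
    # Strip the web-UI path segments so only the in-repo path remains.
    for segment in ["blob/", "main/"]:
        if file_path.startswith(segment):
            file_path = file_path[len(segment):]
    return repo_id, file_path

# Hypothetical example URL:
print(split_hub_link("https://huggingface.co/some-user/some-repo/blob/main/model.safetensors"))
# -> ('some-user/some-repo', 'model.safetensors')
```

The resulting `repo_id` and `file_path` are exactly what the method passes to `hf_hub_download` when the argument is not a local file.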
diffusers/models/modeling_utils.py ADDED
@@ -0,0 +1,978 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import inspect
18
+ import itertools
19
+ import os
20
+ import re
21
+ from functools import partial
22
+ from typing import Any, Callable, List, Optional, Tuple, Union
23
+
24
+ import torch
25
+ from torch import Tensor, device, nn
26
+
27
+ from ..utils.constants import (
28
+ CONFIG_NAME,
29
+ DIFFUSERS_CACHE,
30
+ FLAX_WEIGHTS_NAME,
31
+ SAFETENSORS_WEIGHTS_NAME,
32
+ WEIGHTS_NAME
33
+ )
34
+ from ..utils.hub_utils import (
35
+ HF_HUB_OFFLINE,
36
+ _add_variant,
37
+ _get_model_file
38
+ )
39
+ from ..utils.deprecation_utils import deprecate
40
+ from ..utils.import_utils import (
41
+ is_accelerate_available,
42
+ is_safetensors_available,
43
+ is_torch_version
44
+ )
45
+ from ..utils.logging import get_logger
46
+
47
+ logger = get_logger(__name__)
48
+
49
+
50
+ if is_torch_version(">=", "1.9.0"):
51
+ _LOW_CPU_MEM_USAGE_DEFAULT = True
52
+ else:
53
+ _LOW_CPU_MEM_USAGE_DEFAULT = False
54
+
55
+
56
+ if is_accelerate_available():
57
+ import accelerate
58
+ from accelerate.utils import set_module_tensor_to_device
59
+ from accelerate.utils.versions import is_torch_version
60
+
61
+ if is_safetensors_available():
62
+ import safetensors
63
+
64
+
65
+ def get_parameter_device(parameter: torch.nn.Module):
66
+ try:
67
+ parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
68
+ return next(parameters_and_buffers).device
69
+ except StopIteration:
70
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
71
+
72
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
73
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
74
+ return tuples
75
+
76
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
77
+ first_tuple = next(gen)
78
+ return first_tuple[1].device
79
+
80
+
81
+ def get_parameter_dtype(parameter: torch.nn.Module):
82
+ try:
83
+ params = tuple(parameter.parameters())
84
+ if len(params) > 0:
85
+ return params[0].dtype
86
+
87
+ buffers = tuple(parameter.buffers())
88
+ if len(buffers) > 0:
89
+ return buffers[0].dtype
90
+
91
+ except StopIteration:
92
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
93
+
94
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
95
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
96
+ return tuples
97
+
98
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
99
+ first_tuple = next(gen)
100
+ return first_tuple[1].dtype
101
+
102
+
103
+ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
104
+ """
105
+ Reads a checkpoint file, returning properly formatted errors if they arise.
106
+ """
107
+ try:
108
+ if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
109
+ return torch.load(checkpoint_file, map_location="cpu")
110
+ else:
111
+ return safetensors.torch.load_file(checkpoint_file, device="cpu")
112
+ except Exception as e:
113
+ try:
114
+ with open(checkpoint_file) as f:
115
+ if f.read().startswith("version"):
116
+ raise OSError(
117
+ "You seem to have cloned a repository without having git-lfs installed. Please install "
118
+ "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
119
+ "you cloned."
120
+ )
121
+ else:
122
+ raise ValueError(
123
+ f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
124
+ "model. Make sure you have saved the model properly."
125
+ ) from e
126
+ except (UnicodeDecodeError, ValueError):
127
+ raise OSError(
128
+ f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
129
+ f"at '{checkpoint_file}'. "
130
+ "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
131
+ )
132
+
133
+
134
+ def _load_state_dict_into_model(model_to_load, state_dict):
135
+ # Convert old format to new format if needed from a PyTorch state_dict
136
+ # copy state_dict so _load_from_state_dict can modify it
137
+ state_dict = state_dict.copy()
138
+ error_msgs = []
139
+
140
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
141
+ # so we need to apply the function recursively.
142
+ def load(module: torch.nn.Module, prefix=""):
143
+ args = (state_dict, prefix, {}, True, [], [], error_msgs)
144
+ module._load_from_state_dict(*args)
145
+
146
+ for name, child in module._modules.items():
147
+ if child is not None:
148
+ load(child, prefix + name + ".")
149
+
150
+ load(model_to_load)
151
+
152
+ return error_msgs
153
+
154
+
155
+ class ModelMixin(torch.nn.Module):
156
+ r"""
157
+ Base class for all models.
158
+
159
+ [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
160
+ and saving models.
161
+
162
+ - **config_name** ([`str`]) -- A filename under which the model should be stored when calling
163
+ [`~models.ModelMixin.save_pretrained`].
164
+ """
165
+ config_name = CONFIG_NAME
166
+ _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
167
+ _supports_gradient_checkpointing = False
168
+ _keys_to_ignore_on_load_unexpected = None
169
+
170
+ def __init__(self):
171
+ super().__init__()
172
+
173
+ def __getattr__(self, name: str) -> Any:
174
+ """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
175
+ config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
176
+ __getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__':
177
+ https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
178
+ """
179
+
180
+ is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
181
+ is_attribute = name in self.__dict__
182
+
183
+ if is_in_config and not is_attribute:
184
+ deprecation_message = (
185
+ f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. "
186
+ f"Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
187
+ )
188
+ deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
189
+ return self._internal_dict[name]
190
+
191
+ # call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
192
+ return super().__getattr__(name)
193
+
194
+ @property
195
+ def is_gradient_checkpointing(self) -> bool:
196
+ """
197
+ Whether gradient checkpointing is activated for this model or not.
198
+
199
+ Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
200
+ activations".
201
+ """
202
+ return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
203
+
204
+ def enable_gradient_checkpointing(self):
205
+ """
206
+ Activates gradient checkpointing for the current model.
207
+
208
+ Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
209
+ activations".
210
+ """
211
+ if not self._supports_gradient_checkpointing:
212
+ raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
213
+ self.apply(partial(self._set_gradient_checkpointing, value=True))
214
+
215
+ def disable_gradient_checkpointing(self):
216
+ """
217
+ Deactivates gradient checkpointing for the current model.
218
+
219
+ Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
220
+ activations".
221
+ """
222
+ if self._supports_gradient_checkpointing:
223
+ self.apply(partial(self._set_gradient_checkpointing, value=False))
224
+
225
+ def set_use_memory_efficient_attention_xformers(
226
+ self, valid: bool, attention_op: Optional[Callable] = None
227
+ ) -> None:
228
+ # Recursively walk through all the children.
229
+ # Any children which exposes the set_use_memory_efficient_attention_xformers method
230
+ # gets the message
231
+ def fn_recursive_set_mem_eff(module: torch.nn.Module):
232
+ if hasattr(module, "set_use_memory_efficient_attention_xformers"):
233
+ module.set_use_memory_efficient_attention_xformers(valid, attention_op)
234
+
235
+ for child in module.children():
236
+ fn_recursive_set_mem_eff(child)
237
+
238
+ for module in self.children():
239
+ if isinstance(module, torch.nn.Module):
240
+ fn_recursive_set_mem_eff(module)
241
+
242
+ def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
243
+ r"""
244
+ Enable memory efficient attention as implemented in xformers.
245
+
246
+ When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
247
+ time. Speed up at training time is not guaranteed.
248
+
249
+ Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
250
+ is used.
251
+
252
+ Parameters:
253
+ attention_op (`Callable`, *optional*):
254
+ Override the default `None` operator for use as `op` argument to the
255
+ [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
256
+ function of xFormers.
257
+
258
+ Examples:
259
+
260
+ ```py
261
+ >>> import torch
262
+ >>> from diffusers import UNet2DConditionModel
263
+ >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
264
+
265
+ >>> model = UNet2DConditionModel.from_pretrained(
266
+ ... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
267
+ ... )
268
+ >>> model = model.to("cuda")
269
+ >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
270
+ ```
271
+ """
272
+ self.set_use_memory_efficient_attention_xformers(True, attention_op)
273
+
274
+ def disable_xformers_memory_efficient_attention(self):
275
+ r"""
276
+ Disable memory efficient attention as implemented in xformers.
277
+ """
278
+ self.set_use_memory_efficient_attention_xformers(False)
279
+
280
+ def save_pretrained(
281
+ self,
282
+ save_directory: Union[str, os.PathLike],
283
+ is_main_process: bool = True,
284
+ save_function: Callable = None,
285
+ safe_serialization: bool = False,
286
+ variant: Optional[str] = None,
287
+ ):
288
+ """
289
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
290
+ [`~models.ModelMixin.from_pretrained`] class method.
291
+
292
+ Arguments:
293
+ save_directory (`str` or `os.PathLike`):
294
+ Directory to which to save. Will be created if it doesn't exist.
295
+ is_main_process (`bool`, *optional*, defaults to `True`):
296
+ Whether the process calling this is the main process or not. Useful during distributed training (for
297
+ example on TPUs), when this function needs to be called on all processes. In this case, set
298
+ `is_main_process=True` only on the main process to avoid race conditions.
299
+ save_function (`Callable`):
300
+ The function to use to save the state dictionary. Useful during distributed training (e.g. on TPUs) when
301
+ `torch.save` needs to be replaced with another method. Can be configured with the environment variable
302
+ `DIFFUSERS_SAVE_MODE`.
303
+ safe_serialization (`bool`, *optional*, defaults to `False`):
304
+ Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
305
+ variant (`str`, *optional*):
306
+ If specified, weights are saved in the format pytorch_model.<variant>.bin.
307
+ """
308
+ if safe_serialization and not is_safetensors_available():
309
+ raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
310
+
311
+ if os.path.isfile(save_directory):
312
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
313
+ return
314
+
315
+ os.makedirs(save_directory, exist_ok=True)
316
+
317
+ model_to_save = self
318
+
319
+ # Attach architecture to the config
320
+ # Save the config
321
+ if is_main_process:
322
+ model_to_save.save_config(save_directory)
323
+
324
+ # Save the model
325
+ state_dict = model_to_save.state_dict()
326
+
327
+ weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
328
+ weights_name = _add_variant(weights_name, variant)
329
+
330
+ # Save the model
331
+ if safe_serialization:
332
+ safetensors.torch.save_file(
333
+ state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"}
334
+ )
335
+ else:
336
+ torch.save(state_dict, os.path.join(save_directory, weights_name))
337
+
338
+ logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
339
+
340
+ @classmethod
341
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
342
+ r"""
343
+ Instantiate a pretrained pytorch model from a pre-trained model configuration.
344
+
345
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
346
+ the model, you should first set it back in training mode with `model.train()`.
347
+
348
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
349
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
350
+ task.
351
+
352
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
353
+ weights are discarded.
354
+
355
+ Parameters:
356
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
357
+ Can be either:
358
+
359
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
360
+ Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
361
+ - A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`], e.g.,
362
+ `./my_model_directory/`.
363
+
364
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
365
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
366
+ standard cache should not be used.
367
+ torch_dtype (`str` or `torch.dtype`, *optional*):
368
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
369
+ will be automatically derived from the model's weights.
370
+ force_download (`bool`, *optional*, defaults to `False`):
371
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
372
+ cached versions if they exist.
373
+ resume_download (`bool`, *optional*, defaults to `False`):
374
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
375
+ file exists.
376
+ proxies (`Dict[str, str]`, *optional*):
377
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
378
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
379
+ output_loading_info(`bool`, *optional*, defaults to `False`):
380
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
381
+ local_files_only(`bool`, *optional*, defaults to `False`):
382
+ Whether or not to only look at local files (i.e., do not try to download the model).
383
+ use_auth_token (`str` or *bool*, *optional*):
384
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
385
+ when running `diffusers-cli login` (stored in `~/.huggingface`).
386
+ revision (`str`, *optional*, defaults to `"main"`):
387
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
388
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
389
+ identifier allowed by git.
390
+ from_flax (`bool`, *optional*, defaults to `False`):
391
+ Load the model weights from a Flax checkpoint save file.
392
+ subfolder (`str`, *optional*, defaults to `""`):
393
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
394
+ huggingface.co or downloaded locally), you can specify the folder name here.
395
+
396
+ mirror (`str`, *optional*):
397
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
398
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
399
+ Please refer to the mirror site for more information.
400
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
401
+ A map that specifies where each submodule should go. It doesn't need to be refined to each
402
+ parameter/buffer name; once a given module name is included, every submodule of it will be sent to the
403
+ same device.
404
+
405
+ To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
406
+ more information about each option see [designing a device
407
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
408
+ max_memory (`Dict`, *optional*):
409
+ A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
410
+ GPU and the available CPU RAM if unset.
411
+ offload_folder (`str` or `os.PathLike`, *optional*):
412
+ If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
413
+ offload_state_dict (`bool`, *optional*):
414
+ If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU
415
+ RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to
416
+ `True` when there is some disk offload.
417
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
418
+ Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
419
+ also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
420
+ model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
421
+ setting this argument to `True` will raise an error.
422
+ variant (`str`, *optional*):
423
+ If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
424
+ ignored when using `from_flax`.
425
+ use_safetensors (`bool`, *optional*, defaults to `None`):
426
+ If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
427
+ `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
428
+ `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
429
+
430
+ <Tip>
431
+
432
+ It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
433
+ models](https://huggingface.co/docs/hub/models-gated#gated-models).
434
+
435
+ </Tip>
436
+
437
+ <Tip>
438
+
439
+ Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
440
+ this method in a firewalled environment.
441
+
442
+ </Tip>
443
+
444
+ """
445
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
446
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
447
+ force_download = kwargs.pop("force_download", False)
448
+ from_flax = kwargs.pop("from_flax", False)
449
+ resume_download = kwargs.pop("resume_download", False)
450
+ proxies = kwargs.pop("proxies", None)
451
+ output_loading_info = kwargs.pop("output_loading_info", False)
452
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
453
+ use_auth_token = kwargs.pop("use_auth_token", None)
454
+ revision = kwargs.pop("revision", None)
455
+ torch_dtype = kwargs.pop("torch_dtype", None)
456
+ subfolder = kwargs.pop("subfolder", None)
457
+ device_map = kwargs.pop("device_map", None)
458
+ max_memory = kwargs.pop("max_memory", None)
459
+ offload_folder = kwargs.pop("offload_folder", None)
460
+ offload_state_dict = kwargs.pop("offload_state_dict", False)
461
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
462
+ variant = kwargs.pop("variant", None)
463
+ use_safetensors = kwargs.pop("use_safetensors", None)
464
+
465
+ if use_safetensors and not is_safetensors_available():
466
+ raise ValueError(
467
+ "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
468
+ )
469
+
470
+ allow_pickle = False
471
+ if use_safetensors is None:
472
+ use_safetensors = is_safetensors_available()
473
+ allow_pickle = True
474
+
475
+ if low_cpu_mem_usage and not is_accelerate_available():
476
+ low_cpu_mem_usage = False
477
+ logger.warning(
478
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
479
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
480
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
481
+ " install accelerate\n```\n."
482
+ )
483
+
484
+ if device_map is not None and not is_accelerate_available():
485
+ raise NotImplementedError(
486
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
487
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
488
+ )
489
+
490
+ # Check if we can handle device_map and dispatching the weights
491
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
492
+ raise NotImplementedError(
493
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
494
+ " `device_map=None`."
495
+ )
496
+
497
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
498
+ raise NotImplementedError(
499
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
500
+ " `low_cpu_mem_usage=False`."
501
+ )
502
+
503
+ if low_cpu_mem_usage is False and device_map is not None:
504
+ raise ValueError(
505
+ f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
506
+ " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
507
+ )
508
+
509
+ # Load config if we don't provide a configuration
510
+ config_path = pretrained_model_name_or_path
511
+
512
+ user_agent = {
513
+ "diffusers": __version__,
514
+ "file_type": "model",
515
+ "framework": "pytorch",
516
+ }
517
+
518
+ # load config
519
+ config, unused_kwargs, commit_hash = cls.load_config(
520
+ config_path,
521
+ cache_dir=cache_dir,
522
+ return_unused_kwargs=True,
523
+ return_commit_hash=True,
524
+ force_download=force_download,
525
+ resume_download=resume_download,
526
+ proxies=proxies,
527
+ local_files_only=local_files_only,
528
+ use_auth_token=use_auth_token,
529
+ revision=revision,
530
+ subfolder=subfolder,
531
+ device_map=device_map,
532
+ max_memory=max_memory,
533
+ offload_folder=offload_folder,
534
+ offload_state_dict=offload_state_dict,
535
+ user_agent=user_agent,
536
+ **kwargs,
537
+ )
538
+
539
+ # load model
540
+ model_file = None
541
+ if from_flax:
542
+ model_file = _get_model_file(
543
+ pretrained_model_name_or_path,
544
+ weights_name=FLAX_WEIGHTS_NAME,
545
+ cache_dir=cache_dir,
546
+ force_download=force_download,
547
+ resume_download=resume_download,
548
+ proxies=proxies,
549
+ local_files_only=local_files_only,
550
+ use_auth_token=use_auth_token,
551
+ revision=revision,
552
+ subfolder=subfolder,
553
+ user_agent=user_agent,
554
+ commit_hash=commit_hash,
555
+ )
556
+ model = cls.from_config(config, **unused_kwargs)
557
+
558
+ # Convert the weights
559
+ from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
560
+
561
+ model = load_flax_checkpoint_in_pytorch_model(model, model_file)
562
+ else:
563
+ if use_safetensors:
564
+ try:
565
+ model_file = _get_model_file(
566
+ pretrained_model_name_or_path,
567
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
568
+ cache_dir=cache_dir,
569
+ force_download=force_download,
570
+ resume_download=resume_download,
571
+ proxies=proxies,
572
+ local_files_only=local_files_only,
573
+ use_auth_token=use_auth_token,
574
+ revision=revision,
575
+ subfolder=subfolder,
576
+ user_agent=user_agent,
577
+ commit_hash=commit_hash,
578
+ )
579
+ except IOError as e:
580
+ if not allow_pickle:
581
+ raise e
582
+ pass
583
+ if model_file is None:
584
+ model_file = _get_model_file(
585
+ pretrained_model_name_or_path,
586
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
587
+ cache_dir=cache_dir,
588
+ force_download=force_download,
589
+ resume_download=resume_download,
590
+ proxies=proxies,
591
+ local_files_only=local_files_only,
592
+ use_auth_token=use_auth_token,
593
+ revision=revision,
594
+ subfolder=subfolder,
595
+ user_agent=user_agent,
596
+ commit_hash=commit_hash,
597
+ )
598
+
599
+ if low_cpu_mem_usage:
600
+ # Instantiate model with empty weights
601
+ with accelerate.init_empty_weights():
602
+ model = cls.from_config(config, **unused_kwargs)
603
+
604
+ # if device_map is None, load the state dict and move the params from meta device to the cpu
605
+ if device_map is None:
606
+ param_device = "cpu"
607
+ state_dict = load_state_dict(model_file, variant=variant)
608
+ model._convert_deprecated_attention_blocks(state_dict)
609
+ # move the params from meta device to cpu
610
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
611
+ if len(missing_keys) > 0:
612
+ raise ValueError(
613
+ f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
614
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
615
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
616
+ " those weights or else make sure your checkpoint file is correct."
617
+ )
618
+ unexpected_keys = []
619
+
620
+ empty_state_dict = model.state_dict()
621
+ for param_name, param in state_dict.items():
622
+ accepts_dtype = "dtype" in set(
623
+ inspect.signature(set_module_tensor_to_device).parameters.keys()
624
+ )
625
+
626
+ if param_name not in empty_state_dict:
627
+ unexpected_keys.append(param_name)
628
+ continue
629
+
630
+ if empty_state_dict[param_name].shape != param.shape:
631
+ raise ValueError(
632
+ f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
633
+ )
634
+
635
+ if accepts_dtype:
636
+ set_module_tensor_to_device(
637
+ model, param_name, param_device, value=param, dtype=torch_dtype
638
+ )
639
+ else:
640
+ set_module_tensor_to_device(model, param_name, param_device, value=param)
641
+
642
+ if cls._keys_to_ignore_on_load_unexpected is not None:
643
+ for pat in cls._keys_to_ignore_on_load_unexpected:
644
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
645
+
646
+ if len(unexpected_keys) > 0:
647
+ logger.warning(
648
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {', '.join(unexpected_keys)}"
649
+ )
650
+
651
+ else: # else let accelerate handle loading and dispatching.
652
+ # Load weights and dispatch according to the device_map
653
+ # by default the device_map is None and the weights are loaded on the CPU
654
+ try:
655
+ accelerate.load_checkpoint_and_dispatch(
656
+ model,
657
+ model_file,
658
+ device_map,
659
+ max_memory=max_memory,
660
+ offload_folder=offload_folder,
661
+ offload_state_dict=offload_state_dict,
662
+ dtype=torch_dtype,
663
+ )
664
+ except AttributeError as e:
665
+ # When using accelerate loading, we do not have the ability to load the state
666
+ # dict and rename the weight names manually. Additionally, accelerate skips
667
+ # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
668
+ # (which look like they should be private variables?), so we can't use the standard hooks
669
+ # to rename parameters on load. We need to mimic the original weight names so the correct
670
+ # attributes are available. After we have loaded the weights, we convert the deprecated
671
+ # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
672
+ # the weights so we don't have to do this again.
673
+
674
+ if "'Attention' object has no attribute" in str(e):
675
+ logger.warning(
676
+ f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
677
+ " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
678
+ " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
679
+ " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
680
+ " please also re-upload it or open a PR on the original repository."
681
+ )
682
+ model._temp_convert_self_to_deprecated_attention_blocks()
683
+ accelerate.load_checkpoint_and_dispatch(
684
+ model,
685
+ model_file,
686
+ device_map,
687
+ max_memory=max_memory,
688
+ offload_folder=offload_folder,
689
+ offload_state_dict=offload_state_dict,
690
+ dtype=torch_dtype,
691
+ )
692
+ model._undo_temp_convert_self_to_deprecated_attention_blocks()
693
+ else:
694
+ raise e
695
+
696
+ loading_info = {
697
+ "missing_keys": [],
698
+ "unexpected_keys": [],
699
+ "mismatched_keys": [],
700
+ "error_msgs": [],
701
+ }
702
+ else:
703
+ model = cls.from_config(config, **unused_kwargs)
704
+
705
+ state_dict = load_state_dict(model_file, variant=variant)
706
+ model._convert_deprecated_attention_blocks(state_dict)
707
+
708
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
709
+ model,
710
+ state_dict,
711
+ model_file,
712
+ pretrained_model_name_or_path,
713
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
714
+ )
715
+
716
+ loading_info = {
717
+ "missing_keys": missing_keys,
718
+ "unexpected_keys": unexpected_keys,
719
+ "mismatched_keys": mismatched_keys,
720
+ "error_msgs": error_msgs,
721
+ }
722
+
723
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
724
+ raise ValueError(
725
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
726
+ )
727
+ elif torch_dtype is not None:
728
+ model = model.to(torch_dtype)
729
+
730
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
731
+
732
+ # Set model in evaluation mode to deactivate DropOut modules by default
733
+ model.eval()
734
+ if output_loading_info:
735
+ return model, loading_info
736
+
737
+ return model
738
+
739
+ @classmethod
740
+ def _load_pretrained_model(
741
+ cls,
742
+ model,
743
+ state_dict,
744
+ resolved_archive_file,
745
+ pretrained_model_name_or_path,
746
+ ignore_mismatched_sizes=False,
747
+ ):
748
+ # Retrieve missing & unexpected_keys
749
+ model_state_dict = model.state_dict()
750
+ loaded_keys = list(state_dict.keys())
751
+
752
+ expected_keys = list(model_state_dict.keys())
753
+
754
+ original_loaded_keys = loaded_keys
755
+
756
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
757
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
758
+
759
+ # Make sure we are able to load base models as well as derived models (with heads)
760
+ model_to_load = model
761
+
762
+ def _find_mismatched_keys(
763
+ state_dict,
764
+ model_state_dict,
765
+ loaded_keys,
766
+ ignore_mismatched_sizes,
767
+ ):
768
+ mismatched_keys = []
769
+ if ignore_mismatched_sizes:
770
+ for checkpoint_key in loaded_keys:
771
+ model_key = checkpoint_key
772
+
773
+ if (
774
+ model_key in model_state_dict
775
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
776
+ ):
777
+ mismatched_keys.append(
778
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
779
+ )
780
+ del state_dict[checkpoint_key]
781
+ return mismatched_keys
782
+
783
+ if state_dict is not None:
784
+ # Whole checkpoint
785
+ mismatched_keys = _find_mismatched_keys(
786
+ state_dict,
787
+ model_state_dict,
788
+ original_loaded_keys,
789
+ ignore_mismatched_sizes,
790
+ )
791
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
792
+
793
+ if len(error_msgs) > 0:
794
+ error_msg = "\n\t".join(error_msgs)
795
+ if "size mismatch" in error_msg:
796
+ error_msg += (
797
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
798
+ )
799
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
800
+
801
+ if len(unexpected_keys) > 0:
802
+ logger.warning(
803
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
804
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
805
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
806
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
807
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
808
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
809
+ " identical (initializing a BertForSequenceClassification model from a"
810
+ " BertForSequenceClassification model)."
811
+ )
812
+ else:
813
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
814
+ if len(missing_keys) > 0:
815
+ logger.warning(
816
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
817
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
818
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
819
+ )
820
+ elif len(mismatched_keys) == 0:
821
+ logger.info(
822
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
823
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
824
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
825
+ " without further training."
826
+ )
827
+ if len(mismatched_keys) > 0:
828
+ mismatched_warning = "\n".join(
829
+ [
830
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
831
+ for key, shape1, shape2 in mismatched_keys
832
+ ]
833
+ )
834
+ logger.warning(
835
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
836
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
837
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
838
+ " able to use it for predictions and inference."
839
+ )
840
+
841
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
842
+
843
+ @property
844
+ def device(self) -> device:
845
+ """
846
+ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
847
+ device).
848
+ """
849
+ return get_parameter_device(self)
850
+
851
+ @property
852
+ def dtype(self) -> torch.dtype:
853
+ """
854
+ `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
855
+ """
856
+ return get_parameter_dtype(self)
857
+
858
+ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
859
+ """
860
+ Get number of (optionally, trainable or non-embeddings) parameters in the module.
861
+
862
+ Args:
863
+ only_trainable (`bool`, *optional*, defaults to `False`):
864
+ Whether or not to return only the number of trainable parameters
865
+
866
+ exclude_embeddings (`bool`, *optional*, defaults to `False`):
867
+ Whether or not to return only the number of non-embeddings parameters
868
+
869
+ Returns:
870
+ `int`: The number of parameters.
871
+ """
872
+
873
+ if exclude_embeddings:
874
+ embedding_param_names = [
875
+ f"{name}.weight"
876
+ for name, module_type in self.named_modules()
877
+ if isinstance(module_type, torch.nn.Embedding)
878
+ ]
879
+ non_embedding_parameters = [
880
+ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
881
+ ]
882
+ return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
883
+ else:
884
+ return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
885
+
886
+ def _convert_deprecated_attention_blocks(self, state_dict):
887
+ deprecated_attention_block_paths = []
888
+
889
+ def recursive_find_attn_block(name, module):
890
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
891
+ deprecated_attention_block_paths.append(name)
892
+
893
+ for sub_name, sub_module in module.named_children():
894
+ sub_name = sub_name if name == "" else f"{name}.{sub_name}"
895
+ recursive_find_attn_block(sub_name, sub_module)
896
+
897
+ recursive_find_attn_block("", self)
898
+
899
+ # NOTE: we have to check if the deprecated parameters are in the state dict
900
+ # because it is possible we are loading from a state dict that was already
901
+ # converted
902
+
903
+ for path in deprecated_attention_block_paths:
904
+ # group_norm path stays the same
905
+
906
+ # query -> to_q
907
+ if f"{path}.query.weight" in state_dict:
908
+ state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
909
+ if f"{path}.query.bias" in state_dict:
910
+ state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
911
+
912
+ # key -> to_k
913
+ if f"{path}.key.weight" in state_dict:
914
+ state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
915
+ if f"{path}.key.bias" in state_dict:
916
+ state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
917
+
918
+ # value -> to_v
919
+ if f"{path}.value.weight" in state_dict:
920
+ state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
921
+ if f"{path}.value.bias" in state_dict:
922
+ state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
923
+
924
+ # proj_attn -> to_out.0
925
+ if f"{path}.proj_attn.weight" in state_dict:
926
+ state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
927
+ if f"{path}.proj_attn.bias" in state_dict:
928
+ state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
929
+
930
+ def _temp_convert_self_to_deprecated_attention_blocks(self):
931
+ deprecated_attention_block_modules = []
932
+
933
+ def recursive_find_attn_block(module):
934
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
935
+ deprecated_attention_block_modules.append(module)
936
+
937
+ for sub_module in module.children():
938
+ recursive_find_attn_block(sub_module)
939
+
940
+ recursive_find_attn_block(self)
941
+
942
+ for module in deprecated_attention_block_modules:
943
+ module.query = module.to_q
944
+ module.key = module.to_k
945
+ module.value = module.to_v
946
+ module.proj_attn = module.to_out[0]
947
+
948
+ # We don't _have_ to delete the old attributes, but it's helpful to ensure
949
+ # that _all_ the weights are loaded into the new attributes and we're not
950
+ # making an incorrect assumption that this model should be converted when
951
+ # it really shouldn't be.
952
+ del module.to_q
953
+ del module.to_k
954
+ del module.to_v
955
+ del module.to_out
956
+
957
+ def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
958
+ deprecated_attention_block_modules = []
959
+
960
+ def recursive_find_attn_block(module):
961
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
962
+ deprecated_attention_block_modules.append(module)
963
+
964
+ for sub_module in module.children():
965
+ recursive_find_attn_block(sub_module)
966
+
967
+ recursive_find_attn_block(self)
968
+
969
+ for module in deprecated_attention_block_modules:
970
+ module.to_q = module.query
971
+ module.to_k = module.key
972
+ module.to_v = module.value
973
+ module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
974
+
975
+ del module.query
976
+ del module.key
977
+ del module.value
978
+ del module.proj_attn
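
Before the next file, here is a minimal sketch of the deprecated attention-block key renaming that `_convert_deprecated_attention_blocks` above performs: old `query`/`key`/`value`/`proj_attn` keys are mapped to `to_q`/`to_k`/`to_v`/`to_out.0`. The module path `mid_block.attentions.0` and the tensor shapes below are hypothetical, chosen only for illustration.

```py
import torch

# Toy state dict using the deprecated attention naming (hypothetical path).
state_dict = {
    "mid_block.attentions.0.query.weight": torch.zeros(4, 4),
    "mid_block.attentions.0.key.weight": torch.zeros(4, 4),
    "mid_block.attentions.0.value.weight": torch.zeros(4, 4),
    "mid_block.attentions.0.proj_attn.weight": torch.zeros(4, 4),
}

# The same old-name -> new-name mapping used in the method above.
renames = {"query": "to_q", "key": "to_k", "value": "to_v", "proj_attn": "to_out.0"}

path = "mid_block.attentions.0"
for old, new in renames.items():
    for suffix in ("weight", "bias"):
        old_key = f"{path}.{old}.{suffix}"
        if old_key in state_dict:  # only rename keys that are actually present
            state_dict[f"{path}.{new}.{suffix}"] = state_dict.pop(old_key)

print(sorted(state_dict))
# ['mid_block.attentions.0.to_k.weight', 'mid_block.attentions.0.to_out.0.weight',
#  'mid_block.attentions.0.to_q.weight', 'mid_block.attentions.0.to_v.weight']
```

Checking key presence before popping is what lets the conversion run safely on checkpoints that were already saved with the new names.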
diffusers/models/prior_transformer.py ADDED
@@ -0,0 +1,194 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Union
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+
8
+ from ..utils.configuration_utils import ConfigMixin, register_to_config
9
+ from ..utils.outputs import BaseOutput
10
+ from .attention import BasicTransformerBlock
11
+ from .embeddings import TimestepEmbedding, Timesteps
12
+ from .modeling_utils import ModelMixin
13
+
14
+
15
+ @dataclass
16
+ class PriorTransformerOutput(BaseOutput):
17
+ """
18
+ Args:
19
+ predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
20
+ The predicted CLIP image embedding conditioned on the CLIP text embedding input.
21
+ """
22
+
23
+ predicted_image_embedding: torch.FloatTensor
24
+
25
+
26
+ class PriorTransformer(ModelMixin, ConfigMixin):
27
+ """
28
+ The prior transformer from unCLIP is used to predict CLIP image embeddings from CLIP text embeddings. Note that the
29
+ transformer predicts the image embeddings through a denoising diffusion process.
30
+
31
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
32
+ implements for all the models (such as downloading or saving, etc.)
33
+
34
+ For more details, see the original paper: https://arxiv.org/abs/2204.06125
35
+
36
+ Parameters:
37
+ num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
38
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
39
+ num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
40
+ embedding_dim (`int`, *optional*, defaults to 768): The dimension of the CLIP embeddings. Note that CLIP
41
+ image embeddings and text embeddings are both the same dimension.
42
+ num_embeddings (`int`, *optional*, defaults to 77): The max number of clip embeddings allowed. I.e. the
43
+ length of the prompt after it has been tokenized.
44
+ additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
45
+ projected hidden_states. The actual length of the used hidden_states is `num_embeddings +
46
+ additional_embeddings`.
47
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
48
+
49
+ """
50
+
51
+ @register_to_config
52
+ def __init__(
53
+ self,
54
+ num_attention_heads: int = 32,
55
+ attention_head_dim: int = 64,
56
+ num_layers: int = 20,
57
+ embedding_dim: int = 768,
58
+ num_embeddings=77,
59
+ additional_embeddings=4,
60
+ dropout: float = 0.0,
61
+ ):
62
+ super().__init__()
63
+ self.num_attention_heads = num_attention_heads
64
+ self.attention_head_dim = attention_head_dim
65
+ inner_dim = num_attention_heads * attention_head_dim
66
+ self.additional_embeddings = additional_embeddings
67
+
68
+ self.time_proj = Timesteps(inner_dim, True, 0)
69
+ self.time_embedding = TimestepEmbedding(inner_dim, inner_dim)
70
+
71
+ self.proj_in = nn.Linear(embedding_dim, inner_dim)
72
+
73
+ self.embedding_proj = nn.Linear(embedding_dim, inner_dim)
74
+ self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
75
+
76
+ self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
77
+
78
+ self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
79
+
80
+ self.transformer_blocks = nn.ModuleList(
81
+ [
82
+ BasicTransformerBlock(
83
+ inner_dim,
84
+ num_attention_heads,
85
+ attention_head_dim,
86
+ dropout=dropout,
87
+ activation_fn="gelu",
88
+ attention_bias=True,
89
+ )
90
+ for d in range(num_layers)
91
+ ]
92
+ )
93
+
94
+ self.norm_out = nn.LayerNorm(inner_dim)
95
+ self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim)
96
+
97
+ causal_attention_mask = torch.full(
98
+ [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
99
+ )
100
+ causal_attention_mask.triu_(1)
101
+ causal_attention_mask = causal_attention_mask[None, ...]
102
+ self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
103
+
104
+ self.clip_mean = nn.Parameter(torch.zeros(1, embedding_dim))
105
+ self.clip_std = nn.Parameter(torch.zeros(1, embedding_dim))
106
+
107
+ def forward(
108
+ self,
109
+ hidden_states,
110
+ timestep: Union[torch.Tensor, float, int],
111
+ proj_embedding: torch.FloatTensor,
112
+ encoder_hidden_states: torch.FloatTensor,
113
+ attention_mask: Optional[torch.BoolTensor] = None,
114
+ return_dict: bool = True,
115
+ ):
116
+ """
117
+ Args:
118
+ hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
119
+ x_t, the currently predicted image embeddings.
120
+ timestep (`torch.long`):
121
+ Current denoising step.
122
+ proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
123
+ Projected embedding vector the denoising process is conditioned on.
124
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
125
+ Hidden states of the text embeddings the denoising process is conditioned on.
126
+ attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
127
+ Text mask for the text embeddings.
128
+ return_dict (`bool`, *optional*, defaults to `True`):
129
+ Whether or not to return a [`models.prior_transformer.PriorTransformerOutput`] instead of a plain
130
+ tuple.
131
+
132
+ Returns:
133
+ [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
134
+ [`~models.prior_transformer.PriorTransformerOutput`] if `return_dict` is True, otherwise a `tuple`. When
135
+ returning a tuple, the first element is the sample tensor.
136
+ """
137
+ batch_size = hidden_states.shape[0]
138
+
139
+ timesteps = timestep
140
+ if not torch.is_tensor(timesteps):
141
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
142
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
143
+ timesteps = timesteps[None].to(hidden_states.device)
144
+
145
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
146
+ timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)
147
+
148
+ timesteps_projected = self.time_proj(timesteps)
149
+
150
+ # timesteps does not contain any weights and will always return f32 tensors
151
+ # but time_embedding might be fp16, so we need to cast here.
152
+ timesteps_projected = timesteps_projected.to(dtype=self.dtype)
153
+ time_embeddings = self.time_embedding(timesteps_projected)
154
+
155
+ proj_embeddings = self.embedding_proj(proj_embedding)
156
+ encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
157
+ hidden_states = self.proj_in(hidden_states)
158
+ prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
159
+ positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
160
+
161
+ hidden_states = torch.cat(
162
+ [
163
+ encoder_hidden_states,
164
+ proj_embeddings[:, None, :],
165
+ time_embeddings[:, None, :],
166
+ hidden_states[:, None, :],
167
+ prd_embedding,
168
+ ],
169
+ dim=1,
170
+ )
171
+
172
+ hidden_states = hidden_states + positional_embeddings
173
+
174
+ if attention_mask is not None:
175
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
176
+ attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
177
+ attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
178
+ attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
179
+
180
+ for block in self.transformer_blocks:
181
+ hidden_states = block(hidden_states, attention_mask=attention_mask)
182
+
183
+ hidden_states = self.norm_out(hidden_states)
184
+ hidden_states = hidden_states[:, -1]
185
+ predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
186
+
187
+ if not return_dict:
188
+ return (predicted_image_embedding,)
189
+
190
+ return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
191
+
192
+ def post_process_latents(self, prior_latents):
193
+ prior_latents = (prior_latents * self.clip_std) + self.clip_mean
194
+ return prior_latents
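A minimal usage sketch for the `PriorTransformer` forward pass and `post_process_latents` defined in the file above. The constructor hyper-parameters and sizes (`embedding_dim=768`, `num_embeddings=77`, `timestep=10`) are illustrative assumptions rather than values fixed by this commit; the input shapes follow the forward docstring.

```python
# Sketch only: constructor arguments and tensor sizes below are assumptions.
import torch
from diffusers.models.prior_transformer import PriorTransformer

batch_size, embedding_dim, num_embeddings = 2, 768, 77
prior = PriorTransformer(embedding_dim=embedding_dim, num_embeddings=num_embeddings)

x_t = torch.randn(batch_size, embedding_dim)                      # currently predicted embedding
proj = torch.randn(batch_size, embedding_dim)                     # projected conditioning vector
text_states = torch.randn(batch_size, num_embeddings, embedding_dim)
text_mask = torch.ones(batch_size, num_embeddings, dtype=torch.bool)

out = prior(x_t, timestep=10, proj_embedding=proj,
            encoder_hidden_states=text_states, attention_mask=text_mask)
# Map the prediction back to the CLIP latent scale (x * clip_std + clip_mean).
denoised = prior.post_process_latents(out.predicted_image_embedding)
```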
diffusers/models/resnet.py ADDED
@@ -0,0 +1,839 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ # `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from functools import partial
17
+ from typing import Optional
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+
23
+ from .attention import AdaGroupNorm
24
+
25
+
26
+ class Upsample1D(nn.Module):
27
+ """
28
+ An upsampling layer with an optional convolution.
29
+
30
+ Parameters:
31
+ channels: channels in the inputs and outputs.
32
+ use_conv: a bool determining if a convolution is applied.
33
+ use_conv_transpose:
34
+ out_channels:
35
+ """
36
+
37
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
38
+ super().__init__()
39
+ self.channels = channels
40
+ self.out_channels = out_channels or channels
41
+ self.use_conv = use_conv
42
+ self.use_conv_transpose = use_conv_transpose
43
+ self.name = name
44
+
45
+ self.conv = None
46
+ if use_conv_transpose:
47
+ self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
48
+ elif use_conv:
49
+ self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
50
+
51
+ def forward(self, x):
52
+ assert x.shape[1] == self.channels
53
+ if self.use_conv_transpose:
54
+ return self.conv(x)
55
+
56
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
57
+
58
+ if self.use_conv:
59
+ x = self.conv(x)
60
+
61
+ return x
62
+
63
+
64
+ class Downsample1D(nn.Module):
65
+ """
66
+ A downsampling layer with an optional convolution.
67
+
68
+ Parameters:
69
+ channels: channels in the inputs and outputs.
70
+ use_conv: a bool determining if a convolution is applied.
71
+ out_channels:
72
+ padding:
73
+ """
74
+
75
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
76
+ super().__init__()
77
+ self.channels = channels
78
+ self.out_channels = out_channels or channels
79
+ self.use_conv = use_conv
80
+ self.padding = padding
81
+ stride = 2
82
+ self.name = name
83
+
84
+ if use_conv:
85
+ self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
86
+ else:
87
+ assert self.channels == self.out_channels
88
+ self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
89
+
90
+ def forward(self, x):
91
+ assert x.shape[1] == self.channels
92
+ return self.conv(x)
93
+
94
+
95
+ class Upsample2D(nn.Module):
96
+ """
97
+ An upsampling layer with an optional convolution.
98
+
99
+ Parameters:
100
+ channels: channels in the inputs and outputs.
101
+ use_conv: a bool determining if a convolution is applied.
102
+ use_conv_transpose:
103
+ out_channels:
104
+ """
105
+
106
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
107
+ super().__init__()
108
+ self.channels = channels
109
+ self.out_channels = out_channels or channels
110
+ self.use_conv = use_conv
111
+ self.use_conv_transpose = use_conv_transpose
112
+ self.name = name
113
+
114
+ conv = None
115
+ if use_conv_transpose:
116
+ conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
117
+ elif use_conv:
118
+ conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
119
+
120
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
121
+ if name == "conv":
122
+ self.conv = conv
123
+ else:
124
+ self.Conv2d_0 = conv
125
+
126
+ def forward(self, hidden_states, output_size=None):
127
+ assert hidden_states.shape[1] == self.channels
128
+
129
+ if self.use_conv_transpose:
130
+ return self.conv(hidden_states)
131
+
132
+ # Cast to float32, as the 'upsample_nearest2d_out_frame' op does not support bfloat16
133
+ # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
134
+ # https://github.com/pytorch/pytorch/issues/86679
135
+ dtype = hidden_states.dtype
136
+ if dtype == torch.bfloat16:
137
+ hidden_states = hidden_states.to(torch.float32)
138
+
139
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
140
+ if hidden_states.shape[0] >= 64:
141
+ hidden_states = hidden_states.contiguous()
142
+
143
+ # if `output_size` is passed we force the interpolation output
144
+ # size and do not make use of `scale_factor=2`
145
+ if output_size is None:
146
+ hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
147
+ else:
148
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
149
+
150
+ # If the input is bfloat16, we cast back to bfloat16
151
+ if dtype == torch.bfloat16:
152
+ hidden_states = hidden_states.to(dtype)
153
+
154
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
155
+ if self.use_conv:
156
+ if self.name == "conv":
157
+ hidden_states = self.conv(hidden_states)
158
+ else:
159
+ hidden_states = self.Conv2d_0(hidden_states)
160
+
161
+ return hidden_states
162
+
163
+
164
+ class Downsample2D(nn.Module):
165
+ """
166
+ A downsampling layer with an optional convolution.
167
+
168
+ Parameters:
169
+ channels: channels in the inputs and outputs.
170
+ use_conv: a bool determining if a convolution is applied.
171
+ out_channels:
172
+ padding:
173
+ """
174
+
175
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
176
+ super().__init__()
177
+ self.channels = channels
178
+ self.out_channels = out_channels or channels
179
+ self.use_conv = use_conv
180
+ self.padding = padding
181
+ stride = 2
182
+ self.name = name
183
+
184
+ if use_conv:
185
+ conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
186
+ else:
187
+ assert self.channels == self.out_channels
188
+ conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
189
+
190
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
191
+ if name == "conv":
192
+ self.Conv2d_0 = conv
193
+ self.conv = conv
194
+ elif name == "Conv2d_0":
195
+ self.conv = conv
196
+ else:
197
+ self.conv = conv
198
+
199
+ def forward(self, hidden_states):
200
+ assert hidden_states.shape[1] == self.channels
201
+ if self.use_conv and self.padding == 0:
202
+ pad = (0, 1, 0, 1)
203
+ hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
204
+
205
+ assert hidden_states.shape[1] == self.channels
206
+ hidden_states = self.conv(hidden_states)
207
+
208
+ return hidden_states
209
+
210
+
211
+ class FirUpsample2D(nn.Module):
212
+ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
213
+ super().__init__()
214
+ out_channels = out_channels if out_channels else channels
215
+ if use_conv:
216
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
217
+ self.use_conv = use_conv
218
+ self.fir_kernel = fir_kernel
219
+ self.out_channels = out_channels
220
+
221
+ def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
222
+ """Fused `upsample_2d()` followed by `Conv2d()`.
223
+
224
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
225
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
226
+ arbitrary order.
227
+
228
+ Args:
229
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
230
+ weight: Weight tensor of the shape `[filterH, filterW, inChannels,
231
+ outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
232
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
233
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
234
+ factor: Integer upsampling factor (default: 2).
235
+ gain: Scaling factor for signal magnitude (default: 1.0).
236
+
237
+ Returns:
238
+ output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
239
+ datatype as `hidden_states`.
240
+ """
241
+
242
+ assert isinstance(factor, int) and factor >= 1
243
+
244
+ # Setup filter kernel.
245
+ if kernel is None:
246
+ kernel = [1] * factor
247
+
248
+ # setup kernel
249
+ kernel = torch.tensor(kernel, dtype=torch.float32)
250
+ if kernel.ndim == 1:
251
+ kernel = torch.outer(kernel, kernel)
252
+ kernel /= torch.sum(kernel)
253
+
254
+ kernel = kernel * (gain * (factor**2))
255
+
256
+ if self.use_conv:
257
+ convH = weight.shape[2]
258
+ convW = weight.shape[3]
259
+ inC = weight.shape[1]
260
+
261
+ pad_value = (kernel.shape[0] - factor) - (convW - 1)
262
+
263
+ stride = (factor, factor)
264
+ # Determine data dimensions.
265
+ output_shape = (
266
+ (hidden_states.shape[2] - 1) * factor + convH,
267
+ (hidden_states.shape[3] - 1) * factor + convW,
268
+ )
269
+ output_padding = (
270
+ output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
271
+ output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
272
+ )
273
+ assert output_padding[0] >= 0 and output_padding[1] >= 0
274
+ num_groups = hidden_states.shape[1] // inC
275
+
276
+ # Transpose weights.
277
+ weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
278
+ weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
279
+ weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
280
+
281
+ inverse_conv = F.conv_transpose2d(
282
+ hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
283
+ )
284
+
285
+ output = upfirdn2d_native(
286
+ inverse_conv,
287
+ torch.tensor(kernel, device=inverse_conv.device),
288
+ pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
289
+ )
290
+ else:
291
+ pad_value = kernel.shape[0] - factor
292
+ output = upfirdn2d_native(
293
+ hidden_states,
294
+ torch.tensor(kernel, device=hidden_states.device),
295
+ up=factor,
296
+ pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
297
+ )
298
+
299
+ return output
300
+
301
+ def forward(self, hidden_states):
302
+ if self.use_conv:
303
+ height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
304
+ height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
305
+ else:
306
+ height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
307
+
308
+ return height
309
+
310
+
311
+ class FirDownsample2D(nn.Module):
312
+ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
313
+ super().__init__()
314
+ out_channels = out_channels if out_channels else channels
315
+ if use_conv:
316
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
317
+ self.fir_kernel = fir_kernel
318
+ self.use_conv = use_conv
319
+ self.out_channels = out_channels
320
+
321
+ def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
322
+ """Fused `Conv2d()` followed by `downsample_2d()`.
323
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
324
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
325
+ arbitrary order.
326
+
327
+ Args:
328
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
329
+ weight:
330
+ Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
331
+ performed by `inChannels = x.shape[0] // numGroups`.
332
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
333
+ factor`, which corresponds to average pooling.
334
+ factor: Integer downsampling factor (default: 2).
335
+ gain: Scaling factor for signal magnitude (default: 1.0).
336
+
337
+ Returns:
338
+ output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
339
+ same datatype as `x`.
340
+ """
341
+
342
+ assert isinstance(factor, int) and factor >= 1
343
+ if kernel is None:
344
+ kernel = [1] * factor
345
+
346
+ # setup kernel
347
+ kernel = torch.tensor(kernel, dtype=torch.float32)
348
+ if kernel.ndim == 1:
349
+ kernel = torch.outer(kernel, kernel)
350
+ kernel /= torch.sum(kernel)
351
+
352
+ kernel = kernel * gain
353
+
354
+ if self.use_conv:
355
+ _, _, convH, convW = weight.shape
356
+ pad_value = (kernel.shape[0] - factor) + (convW - 1)
357
+ stride_value = [factor, factor]
358
+ upfirdn_input = upfirdn2d_native(
359
+ hidden_states,
360
+ torch.tensor(kernel, device=hidden_states.device),
361
+ pad=((pad_value + 1) // 2, pad_value // 2),
362
+ )
363
+ output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
364
+ else:
365
+ pad_value = kernel.shape[0] - factor
366
+ output = upfirdn2d_native(
367
+ hidden_states,
368
+ torch.tensor(kernel, device=hidden_states.device),
369
+ down=factor,
370
+ pad=((pad_value + 1) // 2, pad_value // 2),
371
+ )
372
+
373
+ return output
374
+
375
+ def forward(self, hidden_states):
376
+ if self.use_conv:
377
+ downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
378
+ hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
379
+ else:
380
+ hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
381
+
382
+ return hidden_states
383
+
384
+
385
+ # downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/FirUpsample2D instead
386
+ class KDownsample2D(nn.Module):
387
+ def __init__(self, pad_mode="reflect"):
388
+ super().__init__()
389
+ self.pad_mode = pad_mode
390
+ kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
391
+ self.pad = kernel_1d.shape[1] // 2 - 1
392
+ self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
393
+
394
+ def forward(self, x):
395
+ x = F.pad(x, (self.pad,) * 4, self.pad_mode)
396
+ weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
397
+ indices = torch.arange(x.shape[1], device=x.device)
398
+ weight[indices, indices] = self.kernel.to(weight)
399
+ return F.conv2d(x, weight, stride=2)
400
+
401
+
402
+ class KUpsample2D(nn.Module):
403
+ def __init__(self, pad_mode="reflect"):
404
+ super().__init__()
405
+ self.pad_mode = pad_mode
406
+ kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
407
+ self.pad = kernel_1d.shape[1] // 2 - 1
408
+ self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
409
+
410
+ def forward(self, x):
411
+ x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
412
+ weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
413
+ indices = torch.arange(x.shape[1], device=x.device)
414
+ weight[indices, indices] = self.kernel.to(weight)
415
+ return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)
416
+
417
+
418
+ class ResnetBlock2D(nn.Module):
419
+ r"""
420
+ A Resnet block.
421
+
422
+ Parameters:
423
+ in_channels (`int`): The number of channels in the input.
424
+ out_channels (`int`, *optional*, default to be `None`):
425
+ The number of output channels for the first conv2d layer. If None, same as `in_channels`.
426
+ dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
427
+ temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
428
+ groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
429
+ groups_out (`int`, *optional*, default to None):
430
+ The number of groups to use for the second normalization layer. if set to None, same as `groups`.
431
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
432
+ non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
433
+ time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
434
+ By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
435
+ "ada_group" for a stronger conditioning with scale and shift.
436
+ kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
437
+ [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
438
+ output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
439
+ use_in_shortcut (`bool`, *optional*, default to `True`):
440
+ If `True`, add a 1x1 nn.conv2d layer for skip-connection.
441
+ up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
442
+ down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
443
+ conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
444
+ `conv_shortcut` output.
445
+ conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
446
+ If None, same as `out_channels`.
447
+ """
448
+
449
+ def __init__(
450
+ self,
451
+ *,
452
+ in_channels,
453
+ out_channels=None,
454
+ conv_shortcut=False,
455
+ dropout=0.0,
456
+ temb_channels=512,
457
+ groups=32,
458
+ groups_out=None,
459
+ pre_norm=True,
460
+ eps=1e-6,
461
+ non_linearity="swish",
462
+ time_embedding_norm="default", # default, scale_shift, ada_group
463
+ kernel=None,
464
+ output_scale_factor=1.0,
465
+ use_in_shortcut=None,
466
+ up=False,
467
+ down=False,
468
+ conv_shortcut_bias: bool = True,
469
+ conv_2d_out_channels: Optional[int] = None,
470
+ ):
471
+ super().__init__()
472
+ self.pre_norm = pre_norm
473
+ self.pre_norm = True
474
+ self.in_channels = in_channels
475
+ out_channels = in_channels if out_channels is None else out_channels
476
+ self.out_channels = out_channels
477
+ self.use_conv_shortcut = conv_shortcut
478
+ self.up = up
479
+ self.down = down
480
+ self.output_scale_factor = output_scale_factor
481
+ self.time_embedding_norm = time_embedding_norm
482
+
483
+ if groups_out is None:
484
+ groups_out = groups
485
+
486
+ if self.time_embedding_norm == "ada_group":
487
+ self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
488
+ else:
489
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
490
+
491
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
492
+
493
+ if temb_channels is not None:
494
+ if self.time_embedding_norm == "default":
495
+ self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
496
+ elif self.time_embedding_norm == "scale_shift":
497
+ self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
498
+ elif self.time_embedding_norm == "ada_group":
499
+ self.time_emb_proj = None
500
+ else:
501
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
502
+ else:
503
+ self.time_emb_proj = None
504
+
505
+ if self.time_embedding_norm == "ada_group":
506
+ self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
507
+ else:
508
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
509
+
510
+ self.dropout = torch.nn.Dropout(dropout)
511
+ conv_2d_out_channels = conv_2d_out_channels or out_channels
512
+ self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
513
+
514
+ if non_linearity == "swish":
515
+ self.nonlinearity = lambda x: F.silu(x)
516
+ elif non_linearity == "mish":
517
+ self.nonlinearity = nn.Mish()
518
+ elif non_linearity == "silu":
519
+ self.nonlinearity = nn.SiLU()
520
+ elif non_linearity == "gelu":
521
+ self.nonlinearity = nn.GELU()
522
+
523
+ self.upsample = self.downsample = None
524
+ if self.up:
525
+ if kernel == "fir":
526
+ fir_kernel = (1, 3, 3, 1)
527
+ self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
528
+ elif kernel == "sde_vp":
529
+ self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
530
+ else:
531
+ self.upsample = Upsample2D(in_channels, use_conv=False)
532
+ elif self.down:
533
+ if kernel == "fir":
534
+ fir_kernel = (1, 3, 3, 1)
535
+ self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
536
+ elif kernel == "sde_vp":
537
+ self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
538
+ else:
539
+ self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
540
+
541
+ self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
542
+
543
+ self.conv_shortcut = None
544
+ if self.use_in_shortcut:
545
+ self.conv_shortcut = torch.nn.Conv2d(
546
+ in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
547
+ )
548
+
549
+ def forward(self, input_tensor, temb):
550
+ hidden_states = input_tensor
551
+
552
+ if self.time_embedding_norm == "ada_group":
553
+ hidden_states = self.norm1(hidden_states, temb)
554
+ else:
555
+ hidden_states = self.norm1(hidden_states)
556
+
557
+ hidden_states = self.nonlinearity(hidden_states)
558
+
559
+ if self.upsample is not None:
560
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
561
+ if hidden_states.shape[0] >= 64:
562
+ input_tensor = input_tensor.contiguous()
563
+ hidden_states = hidden_states.contiguous()
564
+ input_tensor = self.upsample(input_tensor)
565
+ hidden_states = self.upsample(hidden_states)
566
+ elif self.downsample is not None:
567
+ input_tensor = self.downsample(input_tensor)
568
+ hidden_states = self.downsample(hidden_states)
569
+
570
+ hidden_states = self.conv1(hidden_states)
571
+
572
+ if self.time_emb_proj is not None:
573
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
574
+
575
+ if temb is not None and self.time_embedding_norm == "default":
576
+ hidden_states = hidden_states + temb
577
+
578
+ if self.time_embedding_norm == "ada_group":
579
+ hidden_states = self.norm2(hidden_states, temb)
580
+ else:
581
+ hidden_states = self.norm2(hidden_states)
582
+
583
+ if temb is not None and self.time_embedding_norm == "scale_shift":
584
+ scale, shift = torch.chunk(temb, 2, dim=1)
585
+ hidden_states = hidden_states * (1 + scale) + shift
586
+
587
+ hidden_states = self.nonlinearity(hidden_states)
588
+
589
+ hidden_states = self.dropout(hidden_states)
590
+ hidden_states = self.conv2(hidden_states)
591
+
592
+ if self.conv_shortcut is not None:
593
+ input_tensor = self.conv_shortcut(input_tensor)
594
+
595
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
596
+
597
+ return output_tensor
598
+
599
+
600
+ class Mish(torch.nn.Module):
601
+ def forward(self, hidden_states):
602
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
603
+
604
+
605
+ # unet_rl.py
606
+ def rearrange_dims(tensor):
607
+ if len(tensor.shape) == 2:
608
+ return tensor[:, :, None]
609
+ if len(tensor.shape) == 3:
610
+ return tensor[:, :, None, :]
611
+ elif len(tensor.shape) == 4:
612
+ return tensor[:, :, 0, :]
613
+ else:
614
+ raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
615
+
616
+
617
+ class Conv1dBlock(nn.Module):
618
+ """
619
+ Conv1d --> GroupNorm --> Mish
620
+ """
621
+
622
+ def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
623
+ super().__init__()
624
+
625
+ self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
626
+ self.group_norm = nn.GroupNorm(n_groups, out_channels)
627
+ self.mish = nn.Mish()
628
+
629
+ def forward(self, x):
630
+ x = self.conv1d(x)
631
+ x = rearrange_dims(x)
632
+ x = self.group_norm(x)
633
+ x = rearrange_dims(x)
634
+ x = self.mish(x)
635
+ return x
636
+
637
+
638
+ # unet_rl.py
639
+ class ResidualTemporalBlock1D(nn.Module):
640
+ def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
641
+ super().__init__()
642
+ self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
643
+ self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
644
+
645
+ self.time_emb_act = nn.Mish()
646
+ self.time_emb = nn.Linear(embed_dim, out_channels)
647
+
648
+ self.residual_conv = (
649
+ nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
650
+ )
651
+
652
+ def forward(self, x, t):
653
+ """
654
+ Args:
655
+ x : [ batch_size x inp_channels x horizon ]
656
+ t : [ batch_size x embed_dim ]
657
+
658
+ returns:
659
+ out : [ batch_size x out_channels x horizon ]
660
+ """
661
+ t = self.time_emb_act(t)
662
+ t = self.time_emb(t)
663
+ out = self.conv_in(x) + rearrange_dims(t)
664
+ out = self.conv_out(out)
665
+ return out + self.residual_conv(x)
666
+
667
+
668
+ def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
669
+ r"""Upsample2D a batch of 2D images with the given filter.
670
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
671
+ filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
672
+ `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
673
+ a multiple of the upsampling factor.
674
+
675
+ Args:
676
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
677
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
678
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
679
+ factor: Integer upsampling factor (default: 2).
680
+ gain: Scaling factor for signal magnitude (default: 1.0).
681
+
682
+ Returns:
683
+ output: Tensor of the shape `[N, C, H * factor, W * factor]`
684
+ """
685
+ assert isinstance(factor, int) and factor >= 1
686
+ if kernel is None:
687
+ kernel = [1] * factor
688
+
689
+ kernel = torch.tensor(kernel, dtype=torch.float32)
690
+ if kernel.ndim == 1:
691
+ kernel = torch.outer(kernel, kernel)
692
+ kernel /= torch.sum(kernel)
693
+
694
+ kernel = kernel * (gain * (factor**2))
695
+ pad_value = kernel.shape[0] - factor
696
+ output = upfirdn2d_native(
697
+ hidden_states,
698
+ kernel.to(device=hidden_states.device),
699
+ up=factor,
700
+ pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
701
+ )
702
+ return output
703
+
704
+
705
+ def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
706
+ r"""Downsample2D a batch of 2D images with the given filter.
707
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
708
+ given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
709
+ specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
710
+ shape is a multiple of the downsampling factor.
711
+
712
+ Args:
713
+ hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
714
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
715
+ (separable). The default is `[1] * factor`, which corresponds to average pooling.
716
+ factor: Integer downsampling factor (default: 2).
717
+ gain: Scaling factor for signal magnitude (default: 1.0).
718
+
719
+ Returns:
720
+ output: Tensor of the shape `[N, C, H // factor, W // factor]`
721
+ """
722
+
723
+ assert isinstance(factor, int) and factor >= 1
724
+ if kernel is None:
725
+ kernel = [1] * factor
726
+
727
+ kernel = torch.tensor(kernel, dtype=torch.float32)
728
+ if kernel.ndim == 1:
729
+ kernel = torch.outer(kernel, kernel)
730
+ kernel /= torch.sum(kernel)
731
+
732
+ kernel = kernel * gain
733
+ pad_value = kernel.shape[0] - factor
734
+ output = upfirdn2d_native(
735
+ hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
736
+ )
737
+ return output
738
+
739
+
740
+ def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
741
+ up_x = up_y = up
742
+ down_x = down_y = down
743
+ pad_x0 = pad_y0 = pad[0]
744
+ pad_x1 = pad_y1 = pad[1]
745
+
746
+ _, channel, in_h, in_w = tensor.shape
747
+ tensor = tensor.reshape(-1, in_h, in_w, 1)
748
+
749
+ _, in_h, in_w, minor = tensor.shape
750
+ kernel_h, kernel_w = kernel.shape
751
+
752
+ out = tensor.view(-1, in_h, 1, in_w, 1, minor)
753
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
754
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
755
+
756
+ out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
757
+ out = out.to(tensor.device) # Move back to mps if necessary
758
+ out = out[
759
+ :,
760
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
761
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
762
+ :,
763
+ ]
764
+
765
+ out = out.permute(0, 3, 1, 2)
766
+ out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
767
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
768
+ out = F.conv2d(out, w)
769
+ out = out.reshape(
770
+ -1,
771
+ minor,
772
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
773
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
774
+ )
775
+ out = out.permute(0, 2, 3, 1)
776
+ out = out[:, ::down_y, ::down_x, :]
777
+
778
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
779
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
780
+
781
+ return out.view(-1, channel, out_h, out_w)
782
+
783
+
784
+ class TemporalConvLayer(nn.Module):
785
+ """
786
+ Temporal convolutional layer that can be used for video (sequence of images) input. Code mostly copied from:
787
+ https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
788
+ """
789
+
790
+ def __init__(self, in_dim, out_dim=None, dropout=0.0):
791
+ super().__init__()
792
+ out_dim = out_dim or in_dim
793
+ self.in_dim = in_dim
794
+ self.out_dim = out_dim
795
+
796
+ # conv layers
797
+ self.conv1 = nn.Sequential(
798
+ nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
799
+ )
800
+ self.conv2 = nn.Sequential(
801
+ nn.GroupNorm(32, out_dim),
802
+ nn.SiLU(),
803
+ nn.Dropout(dropout),
804
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
805
+ )
806
+ self.conv3 = nn.Sequential(
807
+ nn.GroupNorm(32, out_dim),
808
+ nn.SiLU(),
809
+ nn.Dropout(dropout),
810
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
811
+ )
812
+ self.conv4 = nn.Sequential(
813
+ nn.GroupNorm(32, out_dim),
814
+ nn.SiLU(),
815
+ nn.Dropout(dropout),
816
+ nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
817
+ )
818
+
819
+ # zero out the last layer params, so the conv block is the identity at initialization
820
+ nn.init.zeros_(self.conv4[-1].weight)
821
+ nn.init.zeros_(self.conv4[-1].bias)
822
+
823
+ def forward(self, hidden_states, num_frames=1):
824
+ hidden_states = (
825
+ hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
826
+ )
827
+
828
+ identity = hidden_states
829
+ hidden_states = self.conv1(hidden_states)
830
+ hidden_states = self.conv2(hidden_states)
831
+ hidden_states = self.conv3(hidden_states)
832
+ hidden_states = self.conv4(hidden_states)
833
+
834
+ hidden_states = identity + hidden_states
835
+
836
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
837
+ (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
838
+ )
839
+ return hidden_states
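A small sketch exercising `ResnetBlock2D` and the FIR helpers from this file. The channel counts and timestep-embedding width are illustrative assumptions, chosen to be divisible by the default 32 GroupNorm groups.

```python
# Sketch only: sizes are assumptions that satisfy the default GroupNorm groups (32).
import torch
from diffusers.models.resnet import ResnetBlock2D, upsample_2d, downsample_2d

block = ResnetBlock2D(in_channels=64, out_channels=128, temb_channels=512)

x = torch.randn(2, 64, 32, 32)   # (batch, channels, height, width)
temb = torch.randn(2, 512)       # timestep embedding
y = block(x, temb)               # -> (2, 128, 32, 32); a 1x1 conv_shortcut matches the channels

x_up = upsample_2d(x, factor=2)      # FIR upsampling   -> (2, 64, 64, 64)
x_down = downsample_2d(x, factor=2)  # FIR downsampling -> (2, 64, 16, 16)
```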
diffusers/models/transformer_2d.py ADDED
@@ -0,0 +1,333 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Any, Dict, Optional
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from torch import nn
21
+
22
+ from ..utils.configuration_utils import ConfigMixin, register_to_config
23
+ from ..utils.outputs import BaseOutput
24
+ from ..utils.deprecation_utils import deprecate
25
+ from ..models.embeddings import ImagePositionalEmbeddings
26
+ from .attention import BasicTransformerBlock
27
+ from .embeddings import PatchEmbed
28
+ from .modeling_utils import ModelMixin
29
+
30
+
31
+ @dataclass
32
+ class Transformer2DModelOutput(BaseOutput):
33
+ """
34
+ Args:
35
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or
36
+ `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
37
+ Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
38
+ for the unnoised latent pixels.
39
+ """
40
+
41
+ sample: torch.FloatTensor
42
+
43
+
44
+ class Transformer2DModel(ModelMixin, ConfigMixin):
45
+ """
46
+ Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
47
+ embeddings) inputs.
48
+
49
+ When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
50
+ transformer action. Finally, reshape to image.
51
+
52
+ When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
53
+ embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
54
+ classes of unnoised image.
55
+
56
+ Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
57
+ image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
58
+
59
+ Parameters:
60
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
61
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
62
+ in_channels (`int`, *optional*):
63
+ Pass if the input is continuous. The number of channels in the input and output.
64
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
65
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
66
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
67
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
68
+ Note that this is fixed at training time as it is used for learning a number of position embeddings.
69
+ See `ImagePositionalEmbeddings`.
70
+ num_vector_embeds (`int`, *optional*):
71
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
72
+ Includes the class for the masked latent pixel.
73
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
74
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
75
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is
76
+ used to learn a number of embeddings that are added to the hidden states. During inference, you can
77
+ denoise for up to, but not more than, `num_embeds_ada_norm` steps.
78
+ attention_bias (`bool`, *optional*):
79
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
80
+ """
81
+
82
+ @register_to_config
83
+ def __init__(
84
+ self,
85
+ num_attention_heads: int = 16,
86
+ attention_head_dim: int = 88,
87
+ in_channels: Optional[int] = None,
88
+ out_channels: Optional[int] = None,
89
+ num_layers: int = 1,
90
+ dropout: float = 0.0,
91
+ norm_num_groups: int = 32,
92
+ cross_attention_dim: Optional[int] = None,
93
+ attention_bias: bool = False,
94
+ sample_size: Optional[int] = None,
95
+ num_vector_embeds: Optional[int] = None,
96
+ patch_size: Optional[int] = None,
97
+ activation_fn: str = "geglu",
98
+ num_embeds_ada_norm: Optional[int] = None,
99
+ use_linear_projection: bool = False,
100
+ only_cross_attention: bool = False,
101
+ upcast_attention: bool = False,
102
+ norm_type: str = "layer_norm",
103
+ norm_elementwise_affine: bool = True,
104
+ ):
105
+ super().__init__()
106
+ self.use_linear_projection = use_linear_projection
107
+ self.num_attention_heads = num_attention_heads
108
+ self.attention_head_dim = attention_head_dim
109
+ inner_dim = num_attention_heads * attention_head_dim
110
+
111
+ # 1. Transformer2DModel can process both standard continuous images of
112
+ # shape `(batch_size, num_channels, width, height)` as well as
113
+ # quantized image embeddings of shape `(batch_size, num_image_vectors)`
114
+ # Define whether input is continuous or discrete depending on configuration
115
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
116
+ self.is_input_vectorized = num_vector_embeds is not None
117
+ self.is_input_patches = in_channels is not None and patch_size is not None
118
+
119
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
120
+ deprecation_message = (
121
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
122
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
123
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
124
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
125
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
126
+ )
127
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
128
+ norm_type = "ada_norm"
129
+
130
+ if self.is_input_continuous and self.is_input_vectorized:
131
+ raise ValueError(
132
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
133
+ " sure that either `in_channels` or `num_vector_embeds` is None."
134
+ )
135
+ elif self.is_input_vectorized and self.is_input_patches:
136
+ raise ValueError(
137
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
138
+ " sure that either `num_vector_embeds` or `num_patches` is None."
139
+ )
140
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
141
+ raise ValueError(
142
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
143
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
144
+ )
145
+
146
+ # 2. Define input layers
147
+ if self.is_input_continuous:
148
+ self.in_channels = in_channels
149
+
150
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
151
+ if use_linear_projection:
152
+ self.proj_in = nn.Linear(in_channels, inner_dim)
153
+ else:
154
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
155
+ elif self.is_input_vectorized:
156
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
157
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
158
+
159
+ self.height = sample_size
160
+ self.width = sample_size
161
+ self.num_vector_embeds = num_vector_embeds
162
+ self.num_latent_pixels = self.height * self.width
163
+
164
+ self.latent_image_embedding = ImagePositionalEmbeddings(
165
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
166
+ )
167
+ elif self.is_input_patches:
168
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
169
+
170
+ self.height = sample_size
171
+ self.width = sample_size
172
+
173
+ self.patch_size = patch_size
174
+ self.pos_embed = PatchEmbed(
175
+ height=sample_size,
176
+ width=sample_size,
177
+ patch_size=patch_size,
178
+ in_channels=in_channels,
179
+ embed_dim=inner_dim,
180
+ )
181
+
182
+ # 3. Define transformers blocks
183
+ self.transformer_blocks = nn.ModuleList(
184
+ [
185
+ BasicTransformerBlock(
186
+ inner_dim,
187
+ num_attention_heads,
188
+ attention_head_dim,
189
+ dropout=dropout,
190
+ cross_attention_dim=cross_attention_dim,
191
+ activation_fn=activation_fn,
192
+ num_embeds_ada_norm=num_embeds_ada_norm,
193
+ attention_bias=attention_bias,
194
+ only_cross_attention=only_cross_attention,
195
+ upcast_attention=upcast_attention,
196
+ norm_type=norm_type,
197
+ norm_elementwise_affine=norm_elementwise_affine,
198
+ )
199
+ for d in range(num_layers)
200
+ ]
201
+ )
202
+
203
+ # 4. Define output layers
204
+ self.out_channels = in_channels if out_channels is None else out_channels
205
+ if self.is_input_continuous:
206
+ # TODO: should use out_channels for continuous projections
207
+ if use_linear_projection:
208
+ self.proj_out = nn.Linear(inner_dim, in_channels)
209
+ else:
210
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
211
+ elif self.is_input_vectorized:
212
+ self.norm_out = nn.LayerNorm(inner_dim)
213
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
214
+ elif self.is_input_patches:
215
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
216
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
217
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
218
+
219
+ def forward(
220
+ self,
221
+ hidden_states: torch.Tensor,
222
+ encoder_hidden_states: Optional[torch.Tensor] = None,
223
+ timestep: Optional[torch.LongTensor] = None,
224
+ class_labels: Optional[torch.LongTensor] = None,
225
+ cross_attention_kwargs: Dict[str, Any] = None,
226
+ attention_mask: Optional[torch.Tensor] = None,
227
+ encoder_attention_mask: Optional[torch.Tensor] = None,
228
+ return_dict: bool = True,
229
+ ):
230
+ """
231
+ Args:
232
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
233
+ When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
234
+ hidden_states
235
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
236
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
237
+ self-attention.
238
+ timestep ( `torch.LongTensor`, *optional*):
239
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
240
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
241
+ Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class
242
+ labels conditioning.
243
+ attention_mask ( `torch.Tensor` of shape (batch size, num latent pixels), *optional* ).
244
+ Bias to add to attention scores.
245
+ encoder_attention_mask ( `torch.Tensor` of shape (batch size, num encoder tokens), *optional* ).
246
+ Bias to add to cross-attention scores.
247
+ return_dict (`bool`, *optional*, defaults to `True`):
248
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
249
+
250
+ Returns:
251
+ [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
252
+ [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`.
253
+ When returning a tuple, the first element is the sample tensor.
254
+ """
255
+ # 1. Input
256
+ if self.is_input_continuous:
257
+ batch, _, height, width = hidden_states.shape
258
+ residual = hidden_states
259
+
260
+ hidden_states = self.norm(hidden_states)
261
+ if not self.use_linear_projection:
262
+ hidden_states = self.proj_in(hidden_states)
263
+ inner_dim = hidden_states.shape[1]
264
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
265
+ else:
266
+ inner_dim = hidden_states.shape[1]
267
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
268
+ hidden_states = self.proj_in(hidden_states)
269
+
270
+ elif self.is_input_vectorized:
271
+ hidden_states = self.latent_image_embedding(hidden_states)
272
+
273
+ elif self.is_input_patches:
274
+ hidden_states = self.pos_embed(hidden_states)
275
+
276
+ # 2. Blocks
277
+ for block in self.transformer_blocks:
278
+ hidden_states = block(
279
+ hidden_states,
280
+ attention_mask=attention_mask,
281
+ encoder_hidden_states=encoder_hidden_states,
282
+ encoder_attention_mask=encoder_attention_mask,
283
+ timestep=timestep,
284
+ cross_attention_kwargs=cross_attention_kwargs,
285
+ class_labels=class_labels,
286
+ )
287
+
288
+ # 3. Output
289
+ if self.is_input_continuous:
290
+ if not self.use_linear_projection:
291
+ hidden_states = hidden_states.reshape(
292
+ batch, height, width, inner_dim
293
+ ).permute(0, 3, 1, 2).contiguous()
294
+ hidden_states = self.proj_out(hidden_states)
295
+ else:
296
+ hidden_states = self.proj_out(hidden_states)
297
+ hidden_states = hidden_states.reshape(
298
+ batch, height, width, inner_dim
299
+ ).permute(0, 3, 1, 2).contiguous()
300
+ output = hidden_states + residual
301
+
302
+ elif self.is_input_vectorized:
303
+ hidden_states = self.norm_out(hidden_states)
304
+ logits = self.out(hidden_states)
305
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
306
+ logits = logits.permute(0, 2, 1)
307
+
308
+ # log(p(x_0))
309
+ output = F.log_softmax(logits.double(), dim=1).float()
310
+
311
+ elif self.is_input_patches:
312
+ # TODO: cleanup!
313
+ conditioning = self.transformer_blocks[0].norm1.emb(
314
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
315
+ )
316
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
317
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
318
+ hidden_states = self.proj_out_2(hidden_states)
319
+
320
+ # unpatchify
321
+ height = width = int(hidden_states.shape[1] ** 0.5)
322
+ hidden_states = hidden_states.reshape(
323
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
324
+ )
325
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
326
+ output = hidden_states.reshape(
327
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
328
+ )
329
+
330
+ if not return_dict:
331
+ return (output,)
332
+
333
+ return Transformer2DModelOutput(sample=output)
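A minimal sketch of the continuous-input path of `Transformer2DModel` above. The head count, channel width, and cross-attention dimension are illustrative assumptions; `in_channels` is kept divisible by the default `norm_num_groups=32`.

```python
# Sketch only: hyper-parameters below are assumptions, not values fixed by this commit.
import torch
from diffusers.models.transformer_2d import Transformer2DModel

model = Transformer2DModel(
    num_attention_heads=8,
    attention_head_dim=8,     # inner_dim = 8 * 8 = 64
    in_channels=64,           # continuous input: divisible by norm_num_groups (32)
    num_layers=1,
    cross_attention_dim=768,
)

latents = torch.randn(2, 64, 16, 16)   # (batch, channels, height, width)
text_states = torch.randn(2, 77, 768)  # conditioning for cross-attention
out = model(latents, encoder_hidden_states=text_states).sample  # same shape as `latents`
```

With `cross_attention_dim` set, each `BasicTransformerBlock` attends to `encoder_hidden_states`; omitting it falls back to self-attention, as noted in the forward docstring.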
diffusers/models/unet_2d.py ADDED
@@ -0,0 +1,315 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from ..utils.configuration_utils import ConfigMixin, register_to_config
21
+ from ..utils.outputs import BaseOutput
22
+ from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
23
+ from .modeling_utils import ModelMixin
24
+ from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
25
+
26
+
27
+ @dataclass
28
+ class UNet2DOutput(BaseOutput):
29
+ """
30
+ Args:
31
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
32
+ Hidden states output. Output of last layer of model.
33
+ """
34
+
35
+ sample: torch.FloatTensor
36
+
37
+
38
+ class UNet2DModel(ModelMixin, ConfigMixin):
39
+ r"""
40
+ UNet2DModel is a 2D UNet model that takes in a noisy sample and a timestep and returns sample shaped output.
41
+
42
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
43
+ implements for all the model (such as downloading or saving, etc.)
44
+
45
+ Parameters:
46
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
47
+ Height and width of input/output sample.
48
+ in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
49
+ out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
50
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
51
+ time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
52
+ freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
53
+ flip_sin_to_cos (`bool`, *optional*, defaults to :
54
+ obj:`True`): Whether to flip sin to cos for fourier time embedding.
55
+ down_block_types (`Tuple[str]`, *optional*, defaults to :
56
+ obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block
57
+ types.
58
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
59
+ The mid block type. Choose from `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
60
+ up_block_types (`Tuple[str]`, *optional*, defaults to :
61
+ obj:`("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): Tuple of upsample block types.
62
+ block_out_channels (`Tuple[int]`, *optional*, defaults to :
63
+ obj:`(224, 448, 672, 896)`): Tuple of block output channels.
64
+ layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
65
+ mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
66
+ downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
67
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
68
+ attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
69
+ norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization.
70
+ norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
71
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
72
+ for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
73
+ class_embed_type (`str`, *optional*, defaults to None):
74
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
75
+ `"timestep"`, or `"identity"`.
76
+ num_class_embeds (`int`, *optional*, defaults to None):
77
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
78
+ class conditioning with `class_embed_type` equal to `None`.
79
+ """
80
+
81
+ @register_to_config
82
+ def __init__(
83
+ self,
84
+ sample_size: Optional[Union[int, Tuple[int, int]]] = None,
85
+ in_channels: int = 3,
86
+ out_channels: int = 3,
87
+ center_input_sample: bool = False,
88
+ time_embedding_type: str = "positional",
89
+ freq_shift: int = 0,
90
+ flip_sin_to_cos: bool = True,
91
+ down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
92
+ up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
93
+ block_out_channels: Tuple[int] = (224, 448, 672, 896),
94
+ layers_per_block: int = 2,
95
+ mid_block_scale_factor: float = 1,
96
+ downsample_padding: int = 1,
97
+ act_fn: str = "silu",
98
+ attention_head_dim: Optional[int] = 8,
99
+ norm_num_groups: int = 32,
100
+ norm_eps: float = 1e-5,
101
+ resnet_time_scale_shift: str = "default",
102
+ add_attention: bool = True,
103
+ class_embed_type: Optional[str] = None,
104
+ num_class_embeds: Optional[int] = None,
105
+ ):
106
+ super().__init__()
107
+
108
+ self.sample_size = sample_size
109
+ time_embed_dim = block_out_channels[0] * 4
110
+
111
+ # Check inputs
112
+ if len(down_block_types) != len(up_block_types):
113
+ raise ValueError(
114
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
115
+ )
116
+
117
+ if len(block_out_channels) != len(down_block_types):
118
+ raise ValueError(
119
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
120
+ )
121
+
122
+ # input
123
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
124
+
125
+ # time
126
+ if time_embedding_type == "fourier":
127
+ self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
128
+ timestep_input_dim = 2 * block_out_channels[0]
129
+ elif time_embedding_type == "positional":
130
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
131
+ timestep_input_dim = block_out_channels[0]
132
+
133
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
134
+
135
+ # class embedding
136
+ if class_embed_type is None and num_class_embeds is not None:
137
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
138
+ elif class_embed_type == "timestep":
139
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
140
+ elif class_embed_type == "identity":
141
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
142
+ else:
143
+ self.class_embedding = None
144
+
145
+ self.down_blocks = nn.ModuleList([])
146
+ self.mid_block = None
147
+ self.up_blocks = nn.ModuleList([])
148
+
149
+ # down
150
+ output_channel = block_out_channels[0]
151
+ for i, down_block_type in enumerate(down_block_types):
152
+ input_channel = output_channel
153
+ output_channel = block_out_channels[i]
154
+ is_final_block = i == len(block_out_channels) - 1
155
+
156
+ down_block = get_down_block(
157
+ down_block_type,
158
+ num_layers=layers_per_block,
159
+ in_channels=input_channel,
160
+ out_channels=output_channel,
161
+ temb_channels=time_embed_dim,
162
+ add_downsample=not is_final_block,
163
+ resnet_eps=norm_eps,
164
+ resnet_act_fn=act_fn,
165
+ resnet_groups=norm_num_groups,
166
+ attn_num_head_channels=attention_head_dim,
167
+ downsample_padding=downsample_padding,
168
+ resnet_time_scale_shift=resnet_time_scale_shift,
169
+ )
170
+ self.down_blocks.append(down_block)
171
+
172
+ # mid
173
+ self.mid_block = UNetMidBlock2D(
174
+ in_channels=block_out_channels[-1],
175
+ temb_channels=time_embed_dim,
176
+ resnet_eps=norm_eps,
177
+ resnet_act_fn=act_fn,
178
+ output_scale_factor=mid_block_scale_factor,
179
+ resnet_time_scale_shift=resnet_time_scale_shift,
180
+ attn_num_head_channels=attention_head_dim,
181
+ resnet_groups=norm_num_groups,
182
+ add_attention=add_attention,
183
+ )
184
+
185
+ # up
186
+ reversed_block_out_channels = list(reversed(block_out_channels))
187
+ output_channel = reversed_block_out_channels[0]
188
+ for i, up_block_type in enumerate(up_block_types):
189
+ prev_output_channel = output_channel
190
+ output_channel = reversed_block_out_channels[i]
191
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
192
+
193
+ is_final_block = i == len(block_out_channels) - 1
194
+
195
+ up_block = get_up_block(
196
+ up_block_type,
197
+ num_layers=layers_per_block + 1,
198
+ in_channels=input_channel,
199
+ out_channels=output_channel,
200
+ prev_output_channel=prev_output_channel,
201
+ temb_channels=time_embed_dim,
202
+ add_upsample=not is_final_block,
203
+ resnet_eps=norm_eps,
204
+ resnet_act_fn=act_fn,
205
+ resnet_groups=norm_num_groups,
206
+ attn_num_head_channels=attention_head_dim,
207
+ resnet_time_scale_shift=resnet_time_scale_shift,
208
+ )
209
+ self.up_blocks.append(up_block)
210
+ prev_output_channel = output_channel
211
+
212
+ # out
213
+ num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
214
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
215
+ self.conv_act = nn.SiLU()
216
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
217
+
218
+ def forward(
219
+ self,
220
+ sample: torch.FloatTensor,
221
+ timestep: Union[torch.Tensor, float, int],
222
+ class_labels: Optional[torch.Tensor] = None,
223
+ return_dict: bool = True,
224
+ ) -> Union[UNet2DOutput, Tuple]:
225
+ r"""
226
+ Args:
227
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
228
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
229
+ class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
230
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
231
+ return_dict (`bool`, *optional*, defaults to `True`):
232
+ Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
233
+
234
+ Returns:
235
+ [`~models.unet_2d.UNet2DOutput`] or `tuple`: [`~models.unet_2d.UNet2DOutput`] if `return_dict` is True,
236
+ otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
237
+ """
238
+ # 0. center input if necessary
239
+ if self.config.center_input_sample:
240
+ sample = 2 * sample - 1.0
241
+
242
+ # 1. time
243
+ timesteps = timestep
244
+ if not torch.is_tensor(timesteps):
245
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
246
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
247
+ timesteps = timesteps[None].to(sample.device)
248
+
249
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
250
+ timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
251
+
252
+ t_emb = self.time_proj(timesteps)
253
+
254
+ # timesteps does not contain any weights and will always return f32 tensors
255
+ # but time_embedding might actually be running in fp16. so we need to cast here.
256
+ # there might be better ways to encapsulate this.
257
+ t_emb = t_emb.to(dtype=self.dtype)
258
+ emb = self.time_embedding(t_emb)
259
+
260
+ if self.class_embedding is not None:
261
+ if class_labels is None:
262
+ raise ValueError("class_labels should be provided when doing class conditioning")
263
+
264
+ if self.config.class_embed_type == "timestep":
265
+ class_labels = self.time_proj(class_labels)
266
+
267
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
268
+ emb = emb + class_emb
269
+
270
+ # 2. pre-process
271
+ skip_sample = sample
272
+ sample = self.conv_in(sample)
273
+
274
+ # 3. down
275
+ down_block_res_samples = (sample,)
276
+ for downsample_block in self.down_blocks:
277
+ if hasattr(downsample_block, "skip_conv"):
278
+ sample, res_samples, skip_sample = downsample_block(
279
+ hidden_states=sample, temb=emb, skip_sample=skip_sample
280
+ )
281
+ else:
282
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
283
+
284
+ down_block_res_samples += res_samples
285
+
286
+ # 4. mid
287
+ sample = self.mid_block(sample, emb)
288
+
289
+ # 5. up
290
+ skip_sample = None
291
+ for upsample_block in self.up_blocks:
292
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
293
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
294
+
295
+ if hasattr(upsample_block, "skip_conv"):
296
+ sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
297
+ else:
298
+ sample = upsample_block(sample, res_samples, emb)
299
+
300
+ # 6. post-process
301
+ sample = self.conv_norm_out(sample)
302
+ sample = self.conv_act(sample)
303
+ sample = self.conv_out(sample)
304
+
305
+ if skip_sample is not None:
306
+ sample += skip_sample
307
+
308
+ if self.config.time_embedding_type == "fourier":
309
+ timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
310
+ sample = sample / timesteps
311
+
312
+ if not return_dict:
313
+ return (sample,)
314
+
315
+ return UNet2DOutput(sample=sample)
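A minimal smoke test for the unconditional 2D UNet defined in this file, assuming the class is exported as `UNet2DModel` (as in upstream diffusers) and that the vendored `diffusers` package is importable from the repository root; the small channel widths are toy values, not a shipped configuration.

```python
# Hedged sketch: exercise the forward() above with dummy inputs.
# `UNet2DModel` and the import path are assumptions based on upstream diffusers.
import torch
from diffusers.models.unet_2d import UNet2DModel

unet = UNet2DModel(
    sample_size=32,                       # height/width of the input sample
    in_channels=3,
    out_channels=3,
    block_out_channels=(32, 64, 64, 64),  # small widths keep the test fast
    layers_per_block=1,
)

sample = torch.randn(2, 3, 32, 32)        # (batch, channel, height, width)
timestep = torch.tensor([10, 500])        # one diffusion timestep per sample

with torch.no_grad():
    out = unet(sample, timestep)          # returns a UNet2DOutput dataclass
print(out.sample.shape)                   # torch.Size([2, 3, 32, 32])
```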
diffusers/models/unet_2d_blocks.py ADDED
The diff for this file is too large to render. See raw diff
 
diffusers/models/unet_2d_condition.py ADDED
@@ -0,0 +1,907 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint
21
+
22
+ from ..utils.configuration_utils import ConfigMixin, register_to_config
23
+ from ..utils.outputs import BaseOutput
24
+ from .loaders import UNet2DConditionLoadersMixin
25
+ from .activations import get_activation
26
+ from .attention_processor import AttentionProcessor, AttnProcessor
27
+ from .embeddings import (
28
+ GaussianFourierProjection,
29
+ TextImageProjection,
30
+ TextImageTimeEmbedding,
31
+ TextTimeEmbedding,
32
+ TimestepEmbedding,
33
+ Timesteps,
34
+ )
35
+ from .modeling_utils import ModelMixin
36
+ from .unet_2d_blocks import (
37
+ CrossAttnDownBlock2D,
38
+ CrossAttnUpBlock2D,
39
+ DownBlock2D,
40
+ UNetMidBlock2DCrossAttn,
41
+ UNetMidBlock2DSimpleCrossAttn,
42
+ UpBlock2D,
43
+ get_down_block,
44
+ get_up_block,
45
+ )
46
+ from ..utils.logging import get_logger
47
+
48
+ logger = get_logger(__name__) # pylint: disable=invalid-name
49
+
50
+
51
+ @dataclass
52
+ class UNet2DConditionOutput(BaseOutput):
53
+ """
54
+ Args:
55
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
56
+ Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
57
+ """
58
+
59
+ sample: torch.FloatTensor
60
+
61
+
62
+ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
63
+ r"""
64
+ UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
65
+ and returns sample shaped output.
66
+
67
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
68
+ implements for all the models (such as downloading or saving, etc.)
69
+
70
+ Parameters:
71
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
72
+ Height and width of input/output sample.
73
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
74
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
75
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
76
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
77
+ Whether to flip the sin to cos in the time embedding.
78
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
79
+ down_block_types (`Tuple[str]`, *optional*,
80
+ defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
81
+ The tuple of downsample blocks to use.
82
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
83
+ The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`, will skip the
84
+ mid block layer if `None`.
85
+ up_block_types (`Tuple[str]`, *optional*,
86
+ defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
87
+ The tuple of upsample blocks to use.
88
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
89
+ Whether to include self-attention in the basic transformer blocks, see
90
+ [`~models.attention.BasicTransformerBlock`].
91
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
92
+ The tuple of output channels for each block.
93
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
94
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
95
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
96
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
97
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
98
+ If `None`, it will skip the normalization and activation layers in post-processing
99
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
100
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
101
+ The dimension of the cross attention features.
102
+ encoder_hid_dim (`int`, *optional*, defaults to None):
103
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
104
+ dimension to `cross_attention_dim`.
105
+ encoder_hid_dim_type (`str`, *optional*, defaults to None):
106
+ If given, the `encoder_hidden_states` and potentially other embeddings will be down-projected to text
107
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
108
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
109
+ num_attention_heads (`int`, *optional*):
110
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
111
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
112
+ for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
113
+ class_embed_type (`str`, *optional*, defaults to None):
114
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
115
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
116
+ addition_embed_type (`str`, *optional*, defaults to None):
117
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
118
+ "text". "text" will use the `TextTimeEmbedding` layer.
119
+ num_class_embeds (`int`, *optional*, defaults to None):
120
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
121
+ class conditioning with `class_embed_type` equal to `None`.
122
+ time_embedding_type (`str`, *optional*, default to `positional`):
123
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
124
+ time_embedding_dim (`int`, *optional*, default to `None`):
125
+ An optional override for the dimension of the projected time embedding.
126
+ time_embedding_act_fn (`str`, *optional*, default to `None`):
127
+ Optional activation function to use on the time embeddings only one time before they as passed to the rest
128
+ of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`.
129
+ timestep_post_act (`str, *optional*, default to `None`):
130
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
131
+ time_cond_proj_dim (`int`, *optional*, default to `None`):
132
+ The dimension of `cond_proj` layer in timestep embedding.
133
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
134
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
135
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
136
+ using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
137
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
138
+ embeddings with the class embeddings.
139
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
140
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
141
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is None, the
142
+ `only_cross_attention` value will be used as the value for `mid_block_only_cross_attention`. Else, it will
143
+ default to `False`.
144
+ """
145
+
146
+ _supports_gradient_checkpointing = True
147
+
148
+ @register_to_config
149
+ def __init__(
150
+ self,
151
+ sample_size: Optional[int] = None,
152
+ in_channels: int = 4,
153
+ out_channels: int = 4,
154
+ center_input_sample: bool = False,
155
+ flip_sin_to_cos: bool = True,
156
+ freq_shift: int = 0,
157
+ down_block_types: Tuple[str] = (
158
+ "CrossAttnDownBlock2D",
159
+ "CrossAttnDownBlock2D",
160
+ "CrossAttnDownBlock2D",
161
+ "DownBlock2D",
162
+ ),
163
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
164
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
165
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
166
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
167
+ layers_per_block: Union[int, Tuple[int]] = 2,
168
+ downsample_padding: int = 1,
169
+ mid_block_scale_factor: float = 1,
170
+ act_fn: str = "silu",
171
+ norm_num_groups: Optional[int] = 32,
172
+ norm_eps: float = 1e-5,
173
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
174
+ encoder_hid_dim: Optional[int] = None,
175
+ encoder_hid_dim_type: Optional[str] = None,
176
+ attention_head_dim: Union[int, Tuple[int]] = 8,
177
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
178
+ dual_cross_attention: bool = False,
179
+ use_linear_projection: bool = False,
180
+ class_embed_type: Optional[str] = None,
181
+ addition_embed_type: Optional[str] = None,
182
+ num_class_embeds: Optional[int] = None,
183
+ upcast_attention: bool = False,
184
+ resnet_time_scale_shift: str = "default",
185
+ resnet_skip_time_act: bool = False,
186
+ resnet_out_scale_factor: int = 1.0,
187
+ time_embedding_type: str = "positional",
188
+ time_embedding_dim: Optional[int] = None,
189
+ time_embedding_act_fn: Optional[str] = None,
190
+ timestep_post_act: Optional[str] = None,
191
+ time_cond_proj_dim: Optional[int] = None,
192
+ conv_in_kernel: int = 3,
193
+ conv_out_kernel: int = 3,
194
+ projection_class_embeddings_input_dim: Optional[int] = None,
195
+ class_embeddings_concat: bool = False,
196
+ mid_block_only_cross_attention: Optional[bool] = None,
197
+ cross_attention_norm: Optional[str] = None,
198
+ addition_embed_type_num_heads=64,
199
+ ):
200
+ super().__init__()
201
+
202
+ self.sample_size = sample_size
203
+
204
+ # If `num_attention_heads` is not defined (which is the case for most models)
205
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
206
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
207
+ # when this library was created. The incorrect naming was only discovered much later in
208
+ # https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
209
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
210
+ # which is why we correct for the naming here.
211
+ num_attention_heads = num_attention_heads or attention_head_dim
212
+
213
+ # Check inputs
214
+ if len(down_block_types) != len(up_block_types):
215
+ raise ValueError(
216
+ "Must provide the same number of `down_block_types` as `up_block_types`. "
217
+ f"`down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
218
+ )
219
+
220
+ if len(block_out_channels) != len(down_block_types):
221
+ raise ValueError(
222
+ "Must provide the same number of `block_out_channels` as `down_block_types`. "
223
+ f"`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
224
+ )
225
+
226
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
227
+ raise ValueError(
228
+ "Must provide the same number of `only_cross_attention` as `down_block_types`. "
229
+ f"`only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
230
+ )
231
+
232
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
233
+ raise ValueError(
234
+ "Must provide the same number of `num_attention_heads` as `down_block_types`. "
235
+ f"`num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
236
+ )
237
+
238
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
239
+ raise ValueError(
240
+ "Must provide the same number of `attention_head_dim` as `down_block_types`. "
241
+ f"`attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
242
+ )
243
+
244
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
245
+ raise ValueError(
246
+ "Must provide the same number of `cross_attention_dim` as `down_block_types`. "
247
+ f"`cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
248
+ )
249
+
250
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
251
+ raise ValueError(
252
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. "
253
+ f"`layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
254
+ )
255
+
256
+ # input
257
+ conv_in_padding = (conv_in_kernel - 1) // 2
258
+ self.conv_in = nn.Conv2d(
259
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
260
+ )
261
+
262
+ # time
263
+ if time_embedding_type == "fourier":
264
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
265
+ if time_embed_dim % 2 != 0:
266
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
267
+ self.time_proj = GaussianFourierProjection(
268
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
269
+ )
270
+ timestep_input_dim = time_embed_dim
271
+ elif time_embedding_type == "positional":
272
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
273
+
274
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
275
+ timestep_input_dim = block_out_channels[0]
276
+ else:
277
+ raise ValueError(
278
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
279
+ )
280
+
281
+ self.time_embedding = TimestepEmbedding(
282
+ timestep_input_dim,
283
+ time_embed_dim,
284
+ act_fn=act_fn,
285
+ post_act_fn=timestep_post_act,
286
+ cond_proj_dim=time_cond_proj_dim,
287
+ )
288
+
289
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
290
+ encoder_hid_dim_type = "text_proj"
291
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
292
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
293
+
294
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
295
+ raise ValueError(
296
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
297
+ )
298
+
299
+ if encoder_hid_dim_type == "text_proj":
300
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
301
+ elif encoder_hid_dim_type == "text_image_proj":
302
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
303
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently
304
+ # only use case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
305
+ self.encoder_hid_proj = TextImageProjection(
306
+ text_embed_dim=encoder_hid_dim,
307
+ image_embed_dim=cross_attention_dim,
308
+ cross_attention_dim=cross_attention_dim,
309
+ )
310
+
311
+ elif encoder_hid_dim_type is not None:
312
+ raise ValueError(
313
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
314
+ )
315
+ else:
316
+ self.encoder_hid_proj = None
317
+
318
+ # class embedding
319
+ if class_embed_type is None and num_class_embeds is not None:
320
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
321
+ elif class_embed_type == "timestep":
322
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
323
+ elif class_embed_type == "identity":
324
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
325
+ elif class_embed_type == "projection":
326
+ if projection_class_embeddings_input_dim is None:
327
+ raise ValueError(
328
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
329
+ )
330
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
331
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
332
+ # 2. it projects from an arbitrary input dimension.
333
+ #
334
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
335
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
336
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
337
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
338
+ elif class_embed_type == "simple_projection":
339
+ if projection_class_embeddings_input_dim is None:
340
+ raise ValueError(
341
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
342
+ )
343
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
344
+ else:
345
+ self.class_embedding = None
346
+
347
+ if addition_embed_type == "text":
348
+ if encoder_hid_dim is not None:
349
+ text_time_embedding_from_dim = encoder_hid_dim
350
+ else:
351
+ text_time_embedding_from_dim = cross_attention_dim
352
+
353
+ self.add_embedding = TextTimeEmbedding(
354
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
355
+ )
356
+ elif addition_embed_type == "text_image":
357
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__
358
+ # too much, they are set to `cross_attention_dim` here as this is exactly the required dimension for the
359
+ # currently only use case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
360
+ self.add_embedding = TextImageTimeEmbedding(
361
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
362
+ )
363
+ elif addition_embed_type is not None:
364
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
365
+
366
+ if time_embedding_act_fn is None:
367
+ self.time_embed_act = None
368
+ else:
369
+ self.time_embed_act = get_activation(time_embedding_act_fn)
370
+
371
+ self.down_blocks = nn.ModuleList([])
372
+ self.up_blocks = nn.ModuleList([])
373
+
374
+ if isinstance(only_cross_attention, bool):
375
+ if mid_block_only_cross_attention is None:
376
+ mid_block_only_cross_attention = only_cross_attention
377
+
378
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
379
+
380
+ if mid_block_only_cross_attention is None:
381
+ mid_block_only_cross_attention = False
382
+
383
+ if isinstance(num_attention_heads, int):
384
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
385
+
386
+ if isinstance(attention_head_dim, int):
387
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
388
+
389
+ if isinstance(cross_attention_dim, int):
390
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
391
+
392
+ if isinstance(layers_per_block, int):
393
+ layers_per_block = [layers_per_block] * len(down_block_types)
394
+
395
+ if class_embeddings_concat:
396
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
397
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
398
+ # regular time embeddings
399
+ blocks_time_embed_dim = time_embed_dim * 2
400
+ else:
401
+ blocks_time_embed_dim = time_embed_dim
402
+
403
+ # down
404
+ output_channel = block_out_channels[0]
405
+ for i, down_block_type in enumerate(down_block_types):
406
+ input_channel = output_channel
407
+ output_channel = block_out_channels[i]
408
+ is_final_block = i == len(block_out_channels) - 1
409
+
410
+ down_block = get_down_block(
411
+ down_block_type,
412
+ num_layers=layers_per_block[i],
413
+ in_channels=input_channel,
414
+ out_channels=output_channel,
415
+ temb_channels=blocks_time_embed_dim,
416
+ add_downsample=not is_final_block,
417
+ resnet_eps=norm_eps,
418
+ resnet_act_fn=act_fn,
419
+ resnet_groups=norm_num_groups,
420
+ cross_attention_dim=cross_attention_dim[i],
421
+ num_attention_heads=num_attention_heads[i],
422
+ downsample_padding=downsample_padding,
423
+ dual_cross_attention=dual_cross_attention,
424
+ use_linear_projection=use_linear_projection,
425
+ only_cross_attention=only_cross_attention[i],
426
+ upcast_attention=upcast_attention,
427
+ resnet_time_scale_shift=resnet_time_scale_shift,
428
+ resnet_skip_time_act=resnet_skip_time_act,
429
+ resnet_out_scale_factor=resnet_out_scale_factor,
430
+ cross_attention_norm=cross_attention_norm,
431
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
432
+ )
433
+ self.down_blocks.append(down_block)
434
+
435
+ # mid
436
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
437
+ self.mid_block = UNetMidBlock2DCrossAttn(
438
+ in_channels=block_out_channels[-1],
439
+ temb_channels=blocks_time_embed_dim,
440
+ resnet_eps=norm_eps,
441
+ resnet_act_fn=act_fn,
442
+ output_scale_factor=mid_block_scale_factor,
443
+ resnet_time_scale_shift=resnet_time_scale_shift,
444
+ cross_attention_dim=cross_attention_dim[-1],
445
+ num_attention_heads=num_attention_heads[-1],
446
+ resnet_groups=norm_num_groups,
447
+ dual_cross_attention=dual_cross_attention,
448
+ use_linear_projection=use_linear_projection,
449
+ upcast_attention=upcast_attention,
450
+ )
451
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
452
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
453
+ in_channels=block_out_channels[-1],
454
+ temb_channels=blocks_time_embed_dim,
455
+ resnet_eps=norm_eps,
456
+ resnet_act_fn=act_fn,
457
+ output_scale_factor=mid_block_scale_factor,
458
+ cross_attention_dim=cross_attention_dim[-1],
459
+ attention_head_dim=attention_head_dim[-1],
460
+ resnet_groups=norm_num_groups,
461
+ resnet_time_scale_shift=resnet_time_scale_shift,
462
+ skip_time_act=resnet_skip_time_act,
463
+ only_cross_attention=mid_block_only_cross_attention,
464
+ cross_attention_norm=cross_attention_norm,
465
+ )
466
+ elif mid_block_type is None:
467
+ self.mid_block = None
468
+ else:
469
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
470
+
471
+ # count how many layers upsample the images
472
+ self.num_upsamplers = 0
473
+
474
+ # up
475
+ reversed_block_out_channels = list(reversed(block_out_channels))
476
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
477
+ reversed_layers_per_block = list(reversed(layers_per_block))
478
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
479
+ only_cross_attention = list(reversed(only_cross_attention))
480
+
481
+ output_channel = reversed_block_out_channels[0]
482
+ for i, up_block_type in enumerate(up_block_types):
483
+ is_final_block = i == len(block_out_channels) - 1
484
+
485
+ prev_output_channel = output_channel
486
+ output_channel = reversed_block_out_channels[i]
487
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
488
+
489
+ # add upsample block for all BUT final layer
490
+ if not is_final_block:
491
+ add_upsample = True
492
+ self.num_upsamplers += 1
493
+ else:
494
+ add_upsample = False
495
+
496
+ up_block = get_up_block(
497
+ up_block_type,
498
+ num_layers=reversed_layers_per_block[i] + 1,
499
+ in_channels=input_channel,
500
+ out_channels=output_channel,
501
+ prev_output_channel=prev_output_channel,
502
+ temb_channels=blocks_time_embed_dim,
503
+ add_upsample=add_upsample,
504
+ resnet_eps=norm_eps,
505
+ resnet_act_fn=act_fn,
506
+ resnet_groups=norm_num_groups,
507
+ cross_attention_dim=reversed_cross_attention_dim[i],
508
+ num_attention_heads=reversed_num_attention_heads[i],
509
+ dual_cross_attention=dual_cross_attention,
510
+ use_linear_projection=use_linear_projection,
511
+ only_cross_attention=only_cross_attention[i],
512
+ upcast_attention=upcast_attention,
513
+ resnet_time_scale_shift=resnet_time_scale_shift,
514
+ resnet_skip_time_act=resnet_skip_time_act,
515
+ resnet_out_scale_factor=resnet_out_scale_factor,
516
+ cross_attention_norm=cross_attention_norm,
517
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
518
+ )
519
+ self.up_blocks.append(up_block)
520
+ prev_output_channel = output_channel
521
+
522
+ # out
523
+ if norm_num_groups is not None:
524
+ self.conv_norm_out = nn.GroupNorm(
525
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
526
+ )
527
+
528
+ self.conv_act = get_activation(act_fn)
529
+
530
+ else:
531
+ self.conv_norm_out = None
532
+ self.conv_act = None
533
+
534
+ conv_out_padding = (conv_out_kernel - 1) // 2
535
+ self.conv_out = nn.Conv2d(
536
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
537
+ )
538
+
539
+ @property
540
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
541
+ r"""
542
+ Returns:
543
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
544
+ indexed by its weight name.
545
+ """
546
+ # set recursively
547
+ processors = {}
548
+
549
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
550
+ if hasattr(module, "set_processor"):
551
+ processors[f"{name}.processor"] = module.processor
552
+
553
+ for sub_name, child in module.named_children():
554
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
555
+
556
+ return processors
557
+
558
+ for name, module in self.named_children():
559
+ fn_recursive_add_processors(name, module, processors)
560
+
561
+ return processors
562
+
563
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
564
+ r"""
565
+ Parameters:
566
+ `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
567
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
568
+ of **all** `Attention` layers.
569
+ In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor.
570
+ This is strongly recommended when setting trainable attention processors.
571
+ """
572
+ count = len(self.attn_processors.keys())
573
+
574
+ if isinstance(processor, dict) and len(processor) != count:
575
+ raise ValueError(
576
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
577
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
578
+ )
579
+
580
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
581
+ if hasattr(module, "set_processor"):
582
+ if not isinstance(processor, dict):
583
+ module.set_processor(processor)
584
+ else:
585
+ module.set_processor(processor.pop(f"{name}.processor"))
586
+
587
+ for sub_name, child in module.named_children():
588
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
589
+
590
+ for name, module in self.named_children():
591
+ fn_recursive_attn_processor(name, module, processor)
592
+
593
+ def set_default_attn_processor(self):
594
+ """
595
+ Disables custom attention processors and sets the default attention implementation.
596
+ """
597
+ self.set_attn_processor(AttnProcessor())
598
+
599
+ def set_attention_slice(self, slice_size):
600
+ r"""
601
+ Enable sliced attention computation.
602
+
603
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
604
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
605
+
606
+ Args:
607
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
608
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
609
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
610
+ provided, uses as many slices as `num_attention_heads // slice_size`. In this case,
611
+ `num_attention_heads` must be a multiple of `slice_size`.
612
+ """
613
+ sliceable_head_dims = []
614
+
615
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
616
+ if hasattr(module, "set_attention_slice"):
617
+ sliceable_head_dims.append(module.sliceable_head_dim)
618
+
619
+ for child in module.children():
620
+ fn_recursive_retrieve_sliceable_dims(child)
621
+
622
+ # retrieve number of attention layers
623
+ for module in self.children():
624
+ fn_recursive_retrieve_sliceable_dims(module)
625
+
626
+ num_sliceable_layers = len(sliceable_head_dims)
627
+
628
+ if slice_size == "auto":
629
+ # half the attention head size is usually a good trade-off between
630
+ # speed and memory
631
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
632
+ elif slice_size == "max":
633
+ # make smallest slice possible
634
+ slice_size = num_sliceable_layers * [1]
635
+
636
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
637
+
638
+ if len(slice_size) != len(sliceable_head_dims):
639
+ raise ValueError(
640
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
641
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
642
+ )
643
+
644
+ for i in range(len(slice_size)):
645
+ size = slice_size[i]
646
+ dim = sliceable_head_dims[i]
647
+ if size is not None and size > dim:
648
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
649
+
650
+ # Recursively walk through all the children.
651
+ # Any children which exposes the set_attention_slice method
652
+ # gets the message
653
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
654
+ if hasattr(module, "set_attention_slice"):
655
+ module.set_attention_slice(slice_size.pop())
656
+
657
+ for child in module.children():
658
+ fn_recursive_set_attention_slice(child, slice_size)
659
+
660
+ reversed_slice_size = list(reversed(slice_size))
661
+ for module in self.children():
662
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
663
+
664
+ def _set_gradient_checkpointing(self, module, value=False):
665
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
666
+ module.gradient_checkpointing = value
667
+
668
+ def forward(
669
+ self,
670
+ sample: torch.FloatTensor,
671
+ timestep: Union[torch.Tensor, float, int],
672
+ encoder_hidden_states: torch.Tensor,
673
+ class_labels: Optional[torch.Tensor] = None,
674
+ timestep_cond: Optional[torch.Tensor] = None,
675
+ attention_mask: Optional[torch.Tensor] = None,
676
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
677
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
678
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
679
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
680
+ encoder_attention_mask: Optional[torch.Tensor] = None,
681
+ return_dict: bool = True,
682
+ **kwargs
683
+ ) -> Union[UNet2DConditionOutput, Tuple]:
684
+ r"""
685
+ Args:
686
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
687
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
688
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
689
+ encoder_attention_mask (`torch.Tensor`):
690
+ (batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep, False =
691
+ discard. Mask will be converted into a bias, which adds large negative values to attention scores
692
+ corresponding to "discard" tokens.
693
+ return_dict (`bool`, *optional*, defaults to `True`):
694
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
695
+ cross_attention_kwargs (`dict`, *optional*):
696
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
697
+ `self.processor` in [diffusers.cross_attention]
698
+ (https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
699
+ added_cond_kwargs (`dict`, *optional*):
700
+ A kwargs dictionary that if specified includes additonal conditions that can be used for additonal time
701
+ embeddings or encoder hidden states projections. See the configurations `encoder_hid_dim_type` and
702
+ `addition_embed_type` for more information.
703
+
704
+ Returns:
705
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
706
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
707
+ returning a tuple, the first element is the sample tensor.
708
+ """
709
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
710
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
711
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
712
+ # on the fly if necessary.
713
+ default_overall_up_factor = 2**self.num_upsamplers
714
+
715
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
716
+ forward_upsample_size = False
717
+ upsample_size = None
718
+
719
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
720
+ logger.info("Forward upsample size to force interpolation output size.")
721
+ forward_upsample_size = True
722
+
723
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
724
+ # expects mask of shape:
725
+ # [batch, key_tokens]
726
+ # adds singleton query_tokens dimension:
727
+ # [batch, 1, key_tokens]
728
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
729
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
730
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
731
+ if attention_mask is not None:
732
+ # assume that mask is expressed as:
733
+ # (1 = keep, 0 = discard)
734
+ # convert mask into a bias that can be added to attention scores:
735
+ # (keep = +0, discard = -10000.0)
736
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
737
+ attention_mask = attention_mask.unsqueeze(1)
738
+
739
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
740
+ if encoder_attention_mask is not None:
741
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
742
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
743
+
744
+ # 0. center input if necessary
745
+ if self.config.center_input_sample:
746
+ sample = 2 * sample - 1.0
747
+
748
+ # 1. time
749
+ timesteps = timestep
750
+ if not torch.is_tensor(timesteps):
751
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
752
+ # This would be a good case for the `match` statement (Python 3.10+)
753
+ is_mps = sample.device.type == "mps"
754
+ if isinstance(timestep, float):
755
+ dtype = torch.float32 if is_mps else torch.float64
756
+ else:
757
+ dtype = torch.int32 if is_mps else torch.int64
758
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
759
+ elif len(timesteps.shape) == 0:
760
+ timesteps = timesteps[None].to(sample.device)
761
+
762
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
763
+ timesteps = timesteps.expand(sample.shape[0])
764
+
765
+ t_emb = self.time_proj(timesteps)
766
+
767
+ # `Timesteps` does not contain any weights and will always return f32 tensors
768
+ # but time_embedding might actually be running in fp16. so we need to cast here.
769
+ # there might be better ways to encapsulate this.
770
+ t_emb = t_emb.to(dtype=sample.dtype)
771
+
772
+ emb = self.time_embedding(t_emb, timestep_cond)
773
+
774
+ if self.class_embedding is not None:
775
+ if class_labels is None:
776
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
777
+
778
+ if self.config.class_embed_type == "timestep":
779
+ class_labels = self.time_proj(class_labels)
780
+
781
+ # `Timesteps` does not contain any weights and will always return f32 tensors
782
+ # there might be better ways to encapsulate this.
783
+ class_labels = class_labels.to(dtype=sample.dtype)
784
+
785
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
786
+
787
+ if self.config.class_embeddings_concat:
788
+ emb = torch.cat([emb, class_emb], dim=-1)
789
+ else:
790
+ emb = emb + class_emb
791
+
792
+ if self.config.addition_embed_type == "text":
793
+ aug_emb = self.add_embedding(encoder_hidden_states)
794
+ emb = emb + aug_emb
795
+ elif self.config.addition_embed_type == "text_image":
796
+ # Kadinsky 2.1 - style
797
+ if "image_embeds" not in added_cond_kwargs:
798
+ raise ValueError(
799
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' "
800
+ "which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
801
+ )
802
+
803
+ image_embs = added_cond_kwargs.get("image_embeds")
804
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
805
+
806
+ aug_emb = self.add_embedding(text_embs, image_embs)
807
+ emb = emb + aug_emb
808
+
809
+ if self.time_embed_act is not None:
810
+ emb = self.time_embed_act(emb)
811
+
812
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
813
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
814
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
815
+ # Kadinsky 2.1 - style
816
+ if "image_embeds" not in added_cond_kwargs:
817
+ raise ValueError(
818
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' "
819
+ "which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
820
+ )
821
+
822
+ image_embeds = added_cond_kwargs.get("image_embeds")
823
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
824
+
825
+ # 2. pre-process
826
+ sample = self.conv_in(sample)
827
+
828
+ # 3. down
829
+ down_block_res_samples = (sample,)
830
+ for downsample_block in self.down_blocks:
831
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
832
+ sample, res_samples = downsample_block(
833
+ hidden_states=sample,
834
+ temb=emb,
835
+ encoder_hidden_states=encoder_hidden_states,
836
+ attention_mask=attention_mask,
837
+ cross_attention_kwargs=cross_attention_kwargs,
838
+ encoder_attention_mask=encoder_attention_mask,
839
+ )
840
+ else:
841
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
842
+
843
+ down_block_res_samples += res_samples
844
+
845
+ if down_block_additional_residuals is not None:
846
+ new_down_block_res_samples = ()
847
+
848
+ for down_block_res_sample, down_block_additional_residual in zip(
849
+ down_block_res_samples, down_block_additional_residuals
850
+ ):
851
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
852
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
853
+
854
+ down_block_res_samples = new_down_block_res_samples
855
+
856
+ # 4. mid
857
+ if self.mid_block is not None:
858
+ sample = self.mid_block(
859
+ sample,
860
+ emb,
861
+ encoder_hidden_states=encoder_hidden_states,
862
+ attention_mask=attention_mask,
863
+ cross_attention_kwargs=cross_attention_kwargs,
864
+ encoder_attention_mask=encoder_attention_mask,
865
+ )
866
+
867
+ if mid_block_additional_residual is not None:
868
+ sample = sample + mid_block_additional_residual
869
+
870
+ # 5. up
871
+ for i, upsample_block in enumerate(self.up_blocks):
872
+ is_final_block = i == len(self.up_blocks) - 1
873
+
874
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
875
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
876
+
877
+ # if we have not reached the final block and need to forward the
878
+ # upsample size, we do it here
879
+ if not is_final_block and forward_upsample_size:
880
+ upsample_size = down_block_res_samples[-1].shape[2:]
881
+
882
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
883
+ sample = upsample_block(
884
+ hidden_states=sample,
885
+ temb=emb,
886
+ res_hidden_states_tuple=res_samples,
887
+ encoder_hidden_states=encoder_hidden_states,
888
+ cross_attention_kwargs=cross_attention_kwargs,
889
+ upsample_size=upsample_size,
890
+ attention_mask=attention_mask,
891
+ encoder_attention_mask=encoder_attention_mask,
892
+ )
893
+ else:
894
+ sample = upsample_block(
895
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
896
+ )
897
+
898
+ # 6. post-process
899
+ if self.conv_norm_out:
900
+ sample = self.conv_norm_out(sample)
901
+ sample = self.conv_act(sample)
902
+ sample = self.conv_out(sample)
903
+
904
+ if not return_dict:
905
+ return (sample,)
906
+
907
+ return UNet2DConditionOutput(sample=sample)
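A hedged sketch of a single conditional forward pass through `UNet2DConditionModel`, matching the `forward()` signature above; the channel widths, sequence length, and `cross_attention_dim` are toy values, and the import path assumes this repo's vendored package layout.

```python
# Hedged sketch: toy conditional forward pass; all sizes are illustrative only.
import torch
from diffusers.models.unet_2d_condition import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=16,
    in_channels=8,
    out_channels=8,
    block_out_channels=(64, 64, 128, 128),
    layers_per_block=1,
    cross_attention_dim=32,    # must match the conditioning feature size below
    attention_head_dim=8,
)

batch, seq_len = 2, 12
sample = torch.randn(batch, 8, 16, 16)                   # noisy latents
timestep = torch.randint(0, 1000, (batch,))              # per-sample timesteps
encoder_hidden_states = torch.randn(batch, seq_len, 32)  # e.g. text-encoder output

# 1 = keep token, 0 = discard; forward() converts this into a -10000 additive bias.
encoder_attention_mask = torch.ones(batch, seq_len, dtype=torch.long)
encoder_attention_mask[:, -4:] = 0                       # pretend the prompt was padded

with torch.no_grad():
    out = unet(
        sample,
        timestep,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
    )
print(out.sample.shape)  # torch.Size([2, 8, 16, 16])
```

With the default four-block layout there are three upsamplers, so the 16x16 spatial size is already a multiple of `2**num_upsamplers`; otherwise `forward()` forwards an explicit upsample size, as its comments note.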
diffusers/models/unet_2d_condition_guided.py ADDED
@@ -0,0 +1,945 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint
21
+
22
+ from ..utils.configuration_utils import ConfigMixin, register_to_config
23
+ from ..utils import logging
24
+ from .loaders import UNet2DConditionLoadersMixin
25
+ from .activations import get_activation
26
+ from .attention_processor import AttentionProcessor, AttnProcessor
27
+ from .embeddings import (
28
+ GaussianFourierProjection,
29
+ TextImageProjection,
30
+ TextImageTimeEmbedding,
31
+ TextTimeEmbedding,
32
+ TimestepEmbedding,
33
+ Timesteps,
34
+ )
35
+ from .modeling_utils import ModelMixin
36
+ from .unet_2d_blocks import (
37
+ CrossAttnDownBlock2D,
38
+ CrossAttnUpBlock2D,
39
+ DownBlock2D,
40
+ UNetMidBlock2DCrossAttn,
41
+ UNetMidBlock2DSimpleCrossAttn,
42
+ UpBlock2D,
43
+ get_down_block,
44
+ get_up_block,
45
+ )
46
+ from .unet_2d_condition import UNet2DConditionOutput
47
+
48
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
49
+
50
+
51
+ class UNet2DConditionGuidedModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
52
+ r"""
53
+ UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample,
54
+ conditional state, and a timestep and returns sample shaped output.
55
+
56
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic
57
+ methods the library implements for all the models (such as downloading or saving, etc.)
58
+
59
+ Parameters:
60
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
61
+ Height and width of input/output sample.
62
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
63
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
64
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
65
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
66
+ Whether to flip the sin to cos in the time embedding.
67
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
68
+ down_block_types (`Tuple[str]`, *optional*, defaults to
69
+ `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
70
+ The tuple of downsample blocks to use.
71
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
72
+ The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`,
73
+ will skip the mid block layer if `None`.
74
+ up_block_types (`Tuple[str]`, *optional*, defaults to
75
+ `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
76
+ The tuple of upsample blocks to use.
77
+ only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
78
+ Whether to include self-attention in the basic transformer blocks, see
79
+ [`~models.attention.BasicTransformerBlock`].
80
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
81
+ The tuple of output channels for each block.
82
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
83
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
84
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
85
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
86
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
87
+ If `None`, it will skip the normalization and activation layers in post-processing
88
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
89
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
90
+ The dimension of the cross attention features.
91
+ encoder_hid_dim (`int`, *optional*, defaults to None):
92
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
93
+ dimension to `cross_attention_dim`.
94
+ encoder_hid_dim_type (`str`, *optional*, defaults to None):
95
+ If given, the `encoder_hidden_states` and potentially other embeddings will be down-projected to text
96
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
97
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
98
+ num_attention_heads (`int`, *optional*):
99
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
100
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
101
+ for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
102
+ class_embed_type (`str`, *optional*, defaults to None):
103
+ The type of class embedding to use which is ultimately summed with the time embeddings.
104
+ Choose from `None`, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
105
+ addition_embed_type (`str`, *optional*, defaults to None):
106
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
107
+ "text". "text" will use the `TextTimeEmbedding` layer.
108
+ num_class_embeds (`int`, *optional*, defaults to None):
109
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
110
+ class conditioning with `class_embed_type` equal to `None`.
111
+ time_embedding_type (`str`, *optional*, default to `positional`):
112
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
113
+ time_embedding_dim (`int`, *optional*, default to `None`):
114
+ An optional override for the dimension of the projected time embedding.
115
+ time_embedding_act_fn (`str`, *optional*, default to `None`):
116
+ Optional activation function applied once to the time embeddings before they are passed
117
+ to the rest of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`.
118
+ timestep_post_act (`str`, *optional*, default to `None`):
119
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
120
+ time_cond_proj_dim (`int`, *optional*, default to `None`):
121
+ The dimension of `cond_proj` layer in timestep embedding.
122
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
123
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
124
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
125
+ using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
126
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
127
+ embeddings with the class embeddings.
128
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
129
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
130
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is None, the
131
+ `only_cross_attention` value will be used as the value for `mid_block_only_cross_attention`.
132
+ Else, it will default to `False`.
133
+ """
134
+
135
+ _supports_gradient_checkpointing = True
136
+
137
+ @register_to_config
138
+ def __init__(
139
+ self,
140
+ sample_size: Optional[int] = None,
141
+ in_channels: int = 4,
142
+ out_channels: int = 4,
143
+ center_input_sample: bool = False,
144
+ flip_sin_to_cos: bool = True,
145
+ freq_shift: int = 0,
146
+ down_block_types: Tuple[str] = (
147
+ "CrossAttnDownBlock2D",
148
+ "CrossAttnDownBlock2D",
149
+ "CrossAttnDownBlock2D",
150
+ "DownBlock2D",
151
+ ),
152
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
153
+ up_block_types: Tuple[str] = (
154
+ "UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"
155
+ ),
156
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
157
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
158
+ layers_per_block: Union[int, Tuple[int]] = 2,
159
+ downsample_padding: int = 1,
160
+ mid_block_scale_factor: float = 1,
161
+ act_fn: str = "silu",
162
+ norm_num_groups: Optional[int] = 32,
163
+ norm_eps: float = 1e-5,
164
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
165
+ encoder_hid_dim: Optional[int] = None,
166
+ encoder_hid_dim_type: Optional[str] = None,
167
+ attention_head_dim: Union[int, Tuple[int]] = 8,
168
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
169
+ dual_cross_attention: bool = False,
170
+ use_linear_projection: bool = False,
171
+ class_embed_type: Optional[str] = None,
172
+ addition_embed_type: Optional[str] = None,
173
+ num_class_embeds: Optional[int] = None,
174
+ upcast_attention: bool = False,
175
+ resnet_time_scale_shift: str = "default",
176
+ resnet_skip_time_act: bool = False,
177
+ resnet_out_scale_factor: int = 1.0,
178
+ time_embedding_type: str = "positional",
179
+ time_embedding_dim: Optional[int] = None,
180
+ time_embedding_act_fn: Optional[str] = None,
181
+ timestep_post_act: Optional[str] = None,
182
+ time_cond_proj_dim: Optional[int] = None,
183
+ guidance_embedding_type: str = "fourier",
184
+ guidance_embedding_dim: Optional[int] = None,
185
+ guidance_post_act: Optional[str] = None,
186
+ guidance_cond_proj_dim: Optional[int] = None,
187
+ conv_in_kernel: int = 3,
188
+ conv_out_kernel: int = 3,
189
+ projection_class_embeddings_input_dim: Optional[int] = None,
190
+ class_embeddings_concat: bool = False,
191
+ mid_block_only_cross_attention: Optional[bool] = None,
192
+ cross_attention_norm: Optional[str] = None,
193
+ addition_embed_type_num_heads=64,
194
+ ):
195
+ super().__init__()
196
+
197
+ self.sample_size = sample_size
198
+
199
+ # If `num_attention_heads` is not defined (which is the case for most models)
200
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
201
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
202
+ # when this library was created. The incorrect naming was only discovered much later in
203
+ # https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
204
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too
205
+ # backwards breaking which is why we correct for the naming here.
206
+ num_attention_heads = num_attention_heads or attention_head_dim
207
+
208
+ # Check inputs
209
+ if len(down_block_types) != len(up_block_types):
210
+ raise ValueError(
211
+ "Must provide the same number of `down_block_types` as `up_block_types`. "
212
+ f"`down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
213
+ )
214
+
215
+ if len(block_out_channels) != len(down_block_types):
216
+ raise ValueError(
217
+ "Must provide the same number of `block_out_channels` as `down_block_types`. "
218
+ f"`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
219
+ )
220
+
221
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
222
+ raise ValueError(
223
+ "Must provide the same number of `only_cross_attention` as `down_block_types`. "
224
+ f"`only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
225
+ )
226
+
227
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
228
+ raise ValueError(
229
+ "Must provide the same number of `num_attention_heads` as `down_block_types`. "
230
+ f"`num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
231
+ )
232
+
233
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
234
+ raise ValueError(
235
+ "Must provide the same number of `attention_head_dim` as `down_block_types`. "
236
+ f"`attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
237
+ )
238
+
239
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
240
+ raise ValueError(
241
+ "Must provide the same number of `cross_attention_dim` as `down_block_types`. "
242
+ f"`cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
243
+ )
244
+
245
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
246
+ raise ValueError(
247
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. "
248
+ f"`layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
249
+ )
250
+
251
+ # input
252
+ conv_in_padding = (conv_in_kernel - 1) // 2
253
+ self.conv_in = nn.Conv2d(
254
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
255
+ )
256
+
257
+ # time and guidance embeddings
258
+ embedding_types = {'time': time_embedding_type, 'guidance': guidance_embedding_type}
259
+ embedding_dims = {'time': time_embedding_dim, 'guidance': guidance_embedding_dim}
260
+ embed_dims, embed_input_dims, embed_projs = {}, {}, {}
261
+
262
+ for key in ['time', 'guidance']:
263
+ logger.info(f"Using {embedding_types[key]} embedding for {key}.")
264
+
265
+ if embedding_types[key] == "fourier":
266
+ embed_dims[key] = embedding_dims[key] or block_out_channels[0] * 4
267
+ embed_input_dims[key] = embed_dims[key]
268
+ if embed_dims[key] % 2 != 0:
269
+ raise ValueError(
270
+ f"`{key}_embed_dim` should be divisible by 2, but is {embed_dims[key]}."
271
+ )
272
+ embed_projs[key] = GaussianFourierProjection(
273
+ embed_dims[key] // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
274
+ )
275
+
276
+ elif embedding_types[key] == "positional":
277
+ embed_dims[key] = embedding_dims[key] or block_out_channels[0] * 4
278
+ embed_input_dims[key] = block_out_channels[0]
279
+ embed_projs[key] = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
280
+
281
+ else:
282
+ raise ValueError(
283
+ f"{embedding_types[key]} does not exist for {key} embedding. "
284
+ f"Please make sure to use one of `fourier` or `positional`."
285
+ )
286
+
287
+ self.time_proj, self.guidance_proj = embed_projs['time'], embed_projs['guidance']
288
+
289
+ self.time_embedding = TimestepEmbedding(
290
+ embed_input_dims['time'],
291
+ embed_dims['time'],
292
+ act_fn=act_fn,
293
+ post_act_fn=timestep_post_act,
294
+ cond_proj_dim=time_cond_proj_dim,
295
+ )
296
+ self.guidance_embedding = TimestepEmbedding(
297
+ embed_input_dims['guidance'],
298
+ embed_dims['guidance'],
299
+ act_fn=act_fn,
300
+ post_act_fn=guidance_post_act,
301
+ cond_proj_dim=guidance_cond_proj_dim,
302
+ )
303
+
304
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
305
+ encoder_hid_dim_type = "text_proj"
306
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
307
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
308
+
309
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
310
+ raise ValueError(
311
+ "`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` "
312
+ f"is set to {encoder_hid_dim_type}."
313
+ )
314
+
315
+ if encoder_hid_dim_type == "text_proj":
316
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
317
+ elif encoder_hid_dim_type == "text_image_proj":
318
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much,
319
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the
320
+ # currently only use case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1).
321
+ self.encoder_hid_proj = TextImageProjection(
322
+ text_embed_dim=encoder_hid_dim,
323
+ image_embed_dim=cross_attention_dim,
324
+ cross_attention_dim=cross_attention_dim,
325
+ )
326
+
327
+ elif encoder_hid_dim_type is not None:
328
+ raise ValueError(
329
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
330
+ )
331
+ else:
332
+ self.encoder_hid_proj = None
333
+
334
+ # class embedding
335
+ # print(f"class_embed_type: {class_embed_type}, num_class_embeds: {num_class_embeds}")
336
+ if class_embed_type is None and num_class_embeds is not None:
337
+ self.class_embedding = nn.Embedding(num_class_embeds, embedding_dims['time'])
338
+ elif class_embed_type == "timestep":
339
+ self.class_embedding = TimestepEmbedding(
340
+ embed_input_dims['time'], embed_dims['time'], act_fn=act_fn
341
+ )
342
+ elif class_embed_type == "identity":
343
+ self.class_embedding = nn.Identity(embed_dims['time'], embed_dims['time'])
344
+ elif class_embed_type == "projection":
345
+ if projection_class_embeddings_input_dim is None:
346
+ raise ValueError(
347
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
348
+ )
349
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
350
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
351
+ # 2. it projects from an arbitrary input dimension.
352
+ #
353
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
354
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal
355
+ # embeddings. As a result, `TimestepEmbedding` can be passed arbitrary vectors.
356
+ self.class_embedding = TimestepEmbedding(
357
+ projection_class_embeddings_input_dim, embed_dims['time']
358
+ )
359
+ elif class_embed_type == "simple_projection":
360
+ if projection_class_embeddings_input_dim is None:
361
+ raise ValueError(
362
+ "`class_embed_type`: 'simple_projection' requires "
363
+ "`projection_class_embeddings_input_dim` be set"
364
+ )
365
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, embed_dims['time'])
366
+ else:
367
+ self.class_embedding = None
368
+
369
+ # Addition embedding
370
+ if addition_embed_type == "text":
371
+ if encoder_hid_dim is not None:
372
+ text_time_embedding_from_dim = encoder_hid_dim
373
+ else:
374
+ text_time_embedding_from_dim = cross_attention_dim
375
+
376
+ self.add_embedding = TextTimeEmbedding(
377
+ text_time_embedding_from_dim, embed_dims['time'], num_heads=addition_embed_type_num_heads
378
+ )
379
+ elif addition_embed_type == "text_image":
380
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`.
381
+ # To not clutter the __init__ too much, they are set to `cross_attention_dim`
382
+ # here as this is exactly the required dimension for the currently only use case
383
+ # when `addition_embed_type == "text_image"` (Kandinsky 2.1).
384
+ self.add_embedding = TextImageTimeEmbedding(
385
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim,
386
+ time_embed_dim=embed_dims['time']
387
+ )
388
+ elif addition_embed_type is not None:
389
+ raise ValueError(
390
+ f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'."
391
+ )
392
+
393
+ # Embedding activation function
394
+ if time_embedding_act_fn is None:
395
+ self.time_embed_act = None
396
+ else:
397
+ self.time_embed_act = get_activation(time_embedding_act_fn)
398
+
399
+ self.down_blocks = nn.ModuleList([])
400
+ self.up_blocks = nn.ModuleList([])
401
+
402
+ if isinstance(only_cross_attention, bool):
403
+ if mid_block_only_cross_attention is None:
404
+ mid_block_only_cross_attention = only_cross_attention
405
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
406
+
407
+ if mid_block_only_cross_attention is None:
408
+ mid_block_only_cross_attention = False
409
+
410
+ if isinstance(num_attention_heads, int):
411
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
412
+
413
+ if isinstance(attention_head_dim, int):
414
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
415
+
416
+ if isinstance(cross_attention_dim, int):
417
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
418
+
419
+ if isinstance(layers_per_block, int):
420
+ layers_per_block = [layers_per_block] * len(down_block_types)
421
+
422
+ if class_embeddings_concat:
423
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
424
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of
425
+ # the regular time embeddings
426
+ # Now we have time emb, guidance emb, and class emb
427
+ blocks_time_embed_dim = embed_dims['time'] * 3
428
+ else:
429
+ blocks_time_embed_dim = embed_dims['time']
430
+
431
+ # down
432
+ output_channel = block_out_channels[0]
433
+ for i, down_block_type in enumerate(down_block_types):
434
+ input_channel = output_channel
435
+ output_channel = block_out_channels[i]
436
+ is_final_block = i == len(block_out_channels) - 1
437
+
438
+ down_block = get_down_block(
439
+ down_block_type,
440
+ num_layers=layers_per_block[i],
441
+ in_channels=input_channel,
442
+ out_channels=output_channel,
443
+ temb_channels=blocks_time_embed_dim,
444
+ add_downsample=not is_final_block,
445
+ resnet_eps=norm_eps,
446
+ resnet_act_fn=act_fn,
447
+ resnet_groups=norm_num_groups,
448
+ cross_attention_dim=cross_attention_dim[i],
449
+ num_attention_heads=num_attention_heads[i],
450
+ downsample_padding=downsample_padding,
451
+ dual_cross_attention=dual_cross_attention,
452
+ use_linear_projection=use_linear_projection,
453
+ only_cross_attention=only_cross_attention[i],
454
+ upcast_attention=upcast_attention,
455
+ resnet_time_scale_shift=resnet_time_scale_shift,
456
+ resnet_skip_time_act=resnet_skip_time_act,
457
+ resnet_out_scale_factor=resnet_out_scale_factor,
458
+ cross_attention_norm=cross_attention_norm,
459
+ attention_head_dim=\
460
+ attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
461
+ )
462
+ self.down_blocks.append(down_block)
463
+
464
+ # mid
465
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
466
+ self.mid_block = UNetMidBlock2DCrossAttn(
467
+ in_channels=block_out_channels[-1],
468
+ temb_channels=blocks_time_embed_dim,
469
+ resnet_eps=norm_eps,
470
+ resnet_act_fn=act_fn,
471
+ output_scale_factor=mid_block_scale_factor,
472
+ resnet_time_scale_shift=resnet_time_scale_shift,
473
+ cross_attention_dim=cross_attention_dim[-1],
474
+ num_attention_heads=num_attention_heads[-1],
475
+ resnet_groups=norm_num_groups,
476
+ dual_cross_attention=dual_cross_attention,
477
+ use_linear_projection=use_linear_projection,
478
+ upcast_attention=upcast_attention,
479
+ )
480
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
481
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
482
+ in_channels=block_out_channels[-1],
483
+ temb_channels=blocks_time_embed_dim,
484
+ resnet_eps=norm_eps,
485
+ resnet_act_fn=act_fn,
486
+ output_scale_factor=mid_block_scale_factor,
487
+ cross_attention_dim=cross_attention_dim[-1],
488
+ attention_head_dim=attention_head_dim[-1],
489
+ resnet_groups=norm_num_groups,
490
+ resnet_time_scale_shift=resnet_time_scale_shift,
491
+ skip_time_act=resnet_skip_time_act,
492
+ only_cross_attention=mid_block_only_cross_attention,
493
+ cross_attention_norm=cross_attention_norm,
494
+ )
495
+ elif mid_block_type is None:
496
+ self.mid_block = None
497
+ else:
498
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
499
+
500
+ # count how many layers upsample the images
501
+ self.num_upsamplers = 0
502
+
503
+ # up
504
+ reversed_block_out_channels = list(reversed(block_out_channels))
505
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
506
+ reversed_layers_per_block = list(reversed(layers_per_block))
507
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
508
+ only_cross_attention = list(reversed(only_cross_attention))
509
+
510
+ output_channel = reversed_block_out_channels[0]
511
+ for i, up_block_type in enumerate(up_block_types):
512
+ is_final_block = i == len(block_out_channels) - 1
513
+
514
+ prev_output_channel = output_channel
515
+ output_channel = reversed_block_out_channels[i]
516
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
517
+
518
+ # add upsample block for all BUT final layer
519
+ if not is_final_block:
520
+ add_upsample = True
521
+ self.num_upsamplers += 1
522
+ else:
523
+ add_upsample = False
524
+
525
+ up_block = get_up_block(
526
+ up_block_type,
527
+ num_layers=reversed_layers_per_block[i] + 1,
528
+ in_channels=input_channel,
529
+ out_channels=output_channel,
530
+ prev_output_channel=prev_output_channel,
531
+ temb_channels=blocks_time_embed_dim,
532
+ add_upsample=add_upsample,
533
+ resnet_eps=norm_eps,
534
+ resnet_act_fn=act_fn,
535
+ resnet_groups=norm_num_groups,
536
+ cross_attention_dim=reversed_cross_attention_dim[i],
537
+ num_attention_heads=reversed_num_attention_heads[i],
538
+ dual_cross_attention=dual_cross_attention,
539
+ use_linear_projection=use_linear_projection,
540
+ only_cross_attention=only_cross_attention[i],
541
+ upcast_attention=upcast_attention,
542
+ resnet_time_scale_shift=resnet_time_scale_shift,
543
+ resnet_skip_time_act=resnet_skip_time_act,
544
+ resnet_out_scale_factor=resnet_out_scale_factor,
545
+ cross_attention_norm=cross_attention_norm,
546
+ attention_head_dim=\
547
+ attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
548
+ )
549
+ self.up_blocks.append(up_block)
550
+ prev_output_channel = output_channel
551
+
552
+ # out
553
+ if norm_num_groups is not None:
554
+ self.conv_norm_out = nn.GroupNorm(
555
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
556
+ )
557
+ self.conv_act = get_activation(act_fn)
558
+
559
+ else:
560
+ self.conv_norm_out = None
561
+ self.conv_act = None
562
+
563
+ conv_out_padding = (conv_out_kernel - 1) // 2
564
+ self.conv_out = nn.Conv2d(
565
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
566
+ )
567
+
568
+ @property
569
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
570
+ r"""
571
+ Returns:
572
+ `dict` of attention processors: A dictionary containing all attention processors used in
573
+ the model, indexed by their weight names.
574
+ """
575
+ # set recursively
576
+ processors = {}
577
+
578
+ def fn_recursive_add_processors(
579
+ name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]
580
+ ):
581
+ if hasattr(module, "set_processor"):
582
+ processors[f"{name}.processor"] = module.processor
583
+
584
+ for sub_name, child in module.named_children():
585
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
586
+
587
+ return processors
588
+
589
+ for name, module in self.named_children():
590
+ fn_recursive_add_processors(name, module, processors)
591
+
592
+ return processors
593
+
594
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
595
+ r"""
596
+ Parameters:
597
+ processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
598
+ The instantiated processor class or a dictionary of processor classes that will be set as the
599
+ processor of **all** `Attention` layers.
600
+ In case `processor` is a dict, the key needs to define the path to the corresponding cross
601
+ attention processor. This is strongly recommended when setting trainable attention processors.
602
+ """
603
+ count = len(self.attn_processors.keys())
604
+
605
+ if isinstance(processor, dict) and len(processor) != count:
606
+ raise ValueError(
607
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match"
608
+ f" the number of attention layers: {count}. Please make sure to pass {count} processor classes."
609
+ )
610
+
611
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
612
+ if hasattr(module, "set_processor"):
613
+ if not isinstance(processor, dict):
614
+ module.set_processor(processor)
615
+ else:
616
+ module.set_processor(processor.pop(f"{name}.processor"))
617
+
618
+ for sub_name, child in module.named_children():
619
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
620
+
621
+ for name, module in self.named_children():
622
+ fn_recursive_attn_processor(name, module, processor)
623
+
624
+ def set_default_attn_processor(self):
625
+ """
626
+ Disables custom attention processors and sets the default attention implementation.
627
+ """
628
+ self.set_attn_processor(AttnProcessor())
629
+
630
+ def set_attention_slice(self, slice_size):
631
+ r"""
632
+ Enable sliced attention computation.
633
+
634
+ When this option is enabled, the attention module will split the input tensor in slices, to compute
635
+ attention in several steps. This is useful to save some memory in exchange for a small speed decrease.
636
+
637
+ Args:
638
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
639
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two
640
+ steps. If `"max"`, the maximum amount of memory will be saved by running only one slice at a time.
641
+ If a number is provided, uses as many slices as `num_attention_heads // slice_size`.
642
+ In this case, `num_attention_heads` must be a multiple of `slice_size`.
643
+ """
644
+ sliceable_head_dims = []
645
+
646
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
647
+ if hasattr(module, "set_attention_slice"):
648
+ sliceable_head_dims.append(module.sliceable_head_dim)
649
+
650
+ for child in module.children():
651
+ fn_recursive_retrieve_sliceable_dims(child)
652
+
653
+ # retrieve number of attention layers
654
+ for module in self.children():
655
+ fn_recursive_retrieve_sliceable_dims(module)
656
+
657
+ num_sliceable_layers = len(sliceable_head_dims)
658
+
659
+ if slice_size == "auto":
660
+ # half the attention head size is usually a good trade-off between
661
+ # speed and memory
662
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
663
+ elif slice_size == "max":
664
+ # make smallest slice possible
665
+ slice_size = num_sliceable_layers * [1]
666
+
667
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
668
+
669
+ if len(slice_size) != len(sliceable_head_dims):
670
+ raise ValueError(
671
+ f"You have provided {len(slice_size)}, but {self.config} has "
672
+ f"{len(sliceable_head_dims)} different attention layers. "
673
+ f"Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
674
+ )
675
+
676
+ for i in range(len(slice_size)):
677
+ size = slice_size[i]
678
+ dim = sliceable_head_dims[i]
679
+ if size is not None and size > dim:
680
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
681
+
682
+ # Recursively walk through all the children and call `set_attention_slice`
683
+ # on any child that exposes the method.
684
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
685
+ if hasattr(module, "set_attention_slice"):
686
+ module.set_attention_slice(slice_size.pop())
687
+
688
+ for child in module.children():
689
+ fn_recursive_set_attention_slice(child, slice_size)
690
+
691
+ reversed_slice_size = list(reversed(slice_size))
692
+ for module in self.children():
693
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
694
+
695
+ def _set_gradient_checkpointing(self, module, value=False):
696
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
697
+ module.gradient_checkpointing = value
698
+
699
+ def _prepare_tensor(self, value, device):
700
+ if not torch.is_tensor(value):
701
+ # Requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
702
+ # This would be a good case for the `match` statement (Python 3.10+)
703
+ if isinstance(value, float):
704
+ dtype = torch.float32 if device.type == "mps" else torch.float64
705
+ else:
706
+ dtype = torch.int32 if device.type == "mps" else torch.int64
707
+
708
+ return torch.tensor([value], dtype=dtype, device=device)
709
+
710
+ elif len(value.shape) == 0:
711
+ return value[None].to(device)
712
+
713
+ else:
714
+ return value
715
+
716
+ def forward(
717
+ self,
718
+ sample: torch.FloatTensor,
719
+ timestep: Union[torch.Tensor, float, int],
720
+ guidance: Union[torch.Tensor, float, int],
721
+ encoder_hidden_states: torch.Tensor,
722
+ class_labels: Optional[torch.Tensor] = None,
723
+ timestep_cond: Optional[torch.Tensor] = None,
724
+ guidance_cond: Optional[torch.Tensor] = None,
725
+ attention_mask: Optional[torch.Tensor] = None,
726
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
727
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
728
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
729
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
730
+ encoder_attention_mask: Optional[torch.Tensor] = None,
731
+ return_dict: bool = True,
732
+ **kwargs
733
+ ) -> Union[UNet2DConditionOutput, Tuple]:
734
+ r"""
735
+ Args:
736
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
737
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
738
+ encoder_hidden_states (`torch.FloatTensor`):
739
+ (batch, sequence_length, feature_dim) encoder hidden states
740
+ encoder_attention_mask (`torch.Tensor`):
741
+ (batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep,
742
+ False = discard. Mask will be converted into a bias, which adds large negative values to
743
+ attention scores corresponding to "discard" tokens.
744
+ return_dict (`bool`, *optional*, defaults to `True`):
745
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`]
746
+ instead of a plain tuple.
747
+ cross_attention_kwargs (`dict`, *optional*):
748
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined
749
+ under `self.processor` in [diffusers.cross_attention]
750
+ (https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
751
+ added_cond_kwargs (`dict`, *optional*):
752
+ A kwargs dictionary that, if specified, includes additional conditions that can be used for
753
+ additional time embeddings or encoder hidden states projections. See the configurations
754
+ `encoder_hid_dim_type` and `addition_embed_type` for more information.
755
+
756
+ Returns:
757
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
758
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
759
+ When returning a tuple, the first element is the sample tensor.
760
+ """
761
+ # By default, samples have to be at least a multiple of the overall upsampling factor,
762
+ # which is equal to 2 ** (number of upsampling layers).
763
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
764
+ # on the fly if necessary.
765
+ default_overall_up_factor = 2 ** self.num_upsamplers
766
+
767
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
768
+ forward_upsample_size = False
769
+ upsample_size = None
770
+
771
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
772
+ logger.info("Forward upsample size to force interpolation output size.")
773
+ forward_upsample_size = True
774
+
775
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
776
+ # expects mask of shape:
777
+ # [batch, key_tokens]
778
+ # adds singleton query_tokens dimension:
779
+ # [batch, 1, key_tokens]
780
+ # this helps to broadcast it as a bias over attention scores,
781
+ # which will be in one of the following shapes:
782
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
783
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
784
+ if attention_mask is not None:
785
+ # assume that mask is expressed as:
786
+ # (1 = keep, 0 = discard)
787
+ # convert mask into a bias that can be added to attention scores:
788
+ # (keep = +0, discard = -10000.0)
789
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
790
+ attention_mask = attention_mask.unsqueeze(1)
791
+
792
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
793
+ if encoder_attention_mask is not None:
794
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * (-10000.0)
795
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
796
+
797
+ # 0. center input if necessary
798
+ if self.config.center_input_sample:
799
+ sample = 2 * sample - 1.0
800
+
801
+ # 1. time and guidance
802
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
803
+ timestep = self._prepare_tensor(timestep, sample.device).expand(sample.shape[0])
804
+ # Project to get embedding
805
+ # `Timestep` does not contain any weights and will always return fp32 tensors
806
+ # but time_embedding might actually be running in fp16, so we need to cast here.
807
+ t_emb = self.time_proj(timestep).to(dtype=sample.dtype)
808
+ t_emb = self.time_embedding(t_emb, timestep_cond)
809
+
810
+ guidance = self._prepare_tensor(guidance, sample.device).expand(sample.shape[0])
811
+ g_emb = self.guidance_proj(guidance).to(dtype=sample.dtype)
812
+ g_emb = self.guidance_embedding(g_emb, guidance_cond)
813
+
814
+ # 1.5. prepare other embeddings
815
+ if self.class_embedding is None:
816
+ emb = t_emb + g_emb
817
+ else:
818
+ if class_labels is None:
819
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
820
+ if self.config.class_embed_type == "timestep":
821
+ class_labels = self.time_proj(class_labels).to(dtype=sample.dtype)
822
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
823
+
824
+ if self.config.class_embeddings_concat:
825
+ emb = torch.cat([t_emb, g_emb, class_emb], dim=-1)
826
+ else:
827
+ emb = t_emb + g_emb + class_emb
828
+
829
+ if self.config.addition_embed_type == "text":
830
+ aug_emb = self.add_embedding(encoder_hidden_states)
831
+ emb = emb + aug_emb
832
+ elif self.config.addition_embed_type == "text_image":
833
+ # Kandinsky 2.1 style
834
+ if "image_embeds" not in added_cond_kwargs:
835
+ raise ValueError(
836
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' "
837
+ "which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
838
+ )
839
+
840
+ image_embs = added_cond_kwargs.get("image_embeds")
841
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
842
+
843
+ aug_emb = self.add_embedding(text_embs, image_embs)
844
+ emb = emb + aug_emb
845
+
846
+ if self.time_embed_act is not None:
847
+ emb = self.time_embed_act(emb)
848
+
849
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
850
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
851
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
852
+ # Kandinsky 2.1 style
853
+ if "image_embeds" not in added_cond_kwargs:
854
+ raise ValueError(
855
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' "
856
+ "which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
857
+ )
858
+
859
+ image_embeds = added_cond_kwargs.get("image_embeds")
860
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
861
+
862
+ # 2. pre-process
863
+ sample = self.conv_in(sample)
864
+
865
+ # 3. down
866
+ down_block_res_samples = (sample,)
867
+ for downsample_block in self.down_blocks:
868
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
869
+ sample, res_samples = downsample_block(
870
+ hidden_states=sample,
871
+ temb=emb,
872
+ encoder_hidden_states=encoder_hidden_states,
873
+ attention_mask=attention_mask,
874
+ cross_attention_kwargs=cross_attention_kwargs,
875
+ encoder_attention_mask=encoder_attention_mask,
876
+ )
877
+ else:
878
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
879
+
880
+ down_block_res_samples += res_samples
881
+
882
+ if down_block_additional_residuals is not None:
883
+ new_down_block_res_samples = ()
884
+
885
+ for down_block_res_sample, down_block_additional_residual in zip(
886
+ down_block_res_samples, down_block_additional_residuals
887
+ ):
888
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
889
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
890
+
891
+ down_block_res_samples = new_down_block_res_samples
892
+
893
+ # 4. mid
894
+ if self.mid_block is not None:
895
+ sample = self.mid_block(
896
+ sample,
897
+ emb,
898
+ encoder_hidden_states=encoder_hidden_states,
899
+ attention_mask=attention_mask,
900
+ cross_attention_kwargs=cross_attention_kwargs,
901
+ encoder_attention_mask=encoder_attention_mask,
902
+ )
903
+
904
+ if mid_block_additional_residual is not None:
905
+ sample = sample + mid_block_additional_residual
906
+
907
+ # 5. up
908
+ for i, upsample_block in enumerate(self.up_blocks):
909
+ is_final_block = i == len(self.up_blocks) - 1
910
+
911
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
912
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
913
+
914
+ # if we have not reached the final block and need to forward the
915
+ # upsample size, we do it here
916
+ if not is_final_block and forward_upsample_size:
917
+ upsample_size = down_block_res_samples[-1].shape[2:]
918
+
919
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
920
+ sample = upsample_block(
921
+ hidden_states=sample,
922
+ temb=emb,
923
+ res_hidden_states_tuple=res_samples,
924
+ encoder_hidden_states=encoder_hidden_states,
925
+ cross_attention_kwargs=cross_attention_kwargs,
926
+ upsample_size=upsample_size,
927
+ attention_mask=attention_mask,
928
+ encoder_attention_mask=encoder_attention_mask,
929
+ )
930
+ else:
931
+ sample = upsample_block(
932
+ hidden_states=sample, temb=emb,
933
+ res_hidden_states_tuple=res_samples, upsample_size=upsample_size
934
+ )
935
+
936
+ # 6. post-process
937
+ if self.conv_norm_out:
938
+ sample = self.conv_norm_out(sample)
939
+ sample = self.conv_act(sample)
940
+ sample = self.conv_out(sample)
941
+
942
+ if not return_dict:
943
+ return (sample,)
944
+
945
+ return UNet2DConditionOutput(sample=sample)
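
The main departure from the stock `UNet2DConditionModel` forward pass is the extra `guidance` argument: the classifier-free guidance weight is projected (Fourier or positional) and embedded exactly like a timestep, then added to, or concatenated with, the time embedding. Below is a minimal sketch of calling the guided UNet; the toy configuration, tensor shapes, and import path are illustrative assumptions based on this repository's layout, not values used by the ConsistencyTTA checkpoints.

```python
import torch
# Local module added in this commit (not the PyPI `diffusers` package).
from diffusers.models.unet_2d_condition_guided import UNet2DConditionGuidedModel

# Hypothetical toy configuration, chosen only so the sketch runs quickly.
unet = UNet2DConditionGuidedModel(
    in_channels=8,
    out_channels=8,
    block_out_channels=(64, 128),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    layers_per_block=1,
    cross_attention_dim=256,
)

latents = torch.randn(2, 8, 32, 32)   # (batch, channels, height, width) noisy latent
text_emb = torch.randn(2, 77, 256)    # (batch, seq_len, cross_attention_dim) text condition
timestep = 400                        # diffusion timestep, broadcast over the batch
guidance = 3.0                        # guidance weight, embedded like a second timestep

out = unet(latents, timestep, guidance, encoder_hidden_states=text_emb)
print(out.sample.shape)               # torch.Size([2, 8, 32, 32])
```

Conditioning on the guidance weight this way follows the guided-distillation recipe, where a single forward pass stands in for the conditional and unconditional passes of classifier-free guidance.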
diffusers/scheduling_heun_discrete.py ADDED
@@ -0,0 +1,387 @@
1
+ # Copyright 2023 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ ### This file has been modified for the purposes of ConsistencyTTA generation. ###
16
+
17
+ import math
18
+ from typing import List, Optional, Tuple, Union
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+ from .utils.configuration_utils import ConfigMixin, register_to_config
24
+ from .utils.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
25
+
26
+
27
+ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
28
+ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
29
+ """
30
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines
31
+ the cumulative product of (1-beta) over time from t = [0,1].
32
+
33
+ Contains a function alpha_bar that takes an argument t and transforms it to the
34
+ cumulative product of (1-beta) up to that part of the diffusion process.
35
+
36
+
37
+ Args:
38
+ num_diffusion_timesteps (`int`): the number of betas to produce.
39
+ max_beta (`float`):
40
+ the maximum beta to use; use values lower than 1 to prevent singularities.
41
+
42
+ Returns:
43
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
44
+ """
45
+
46
+ def alpha_bar(time_step):
47
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
48
+
49
+ betas = []
50
+ for i in range(num_diffusion_timesteps):
51
+ t1 = i / num_diffusion_timesteps
52
+ t2 = (i + 1) / num_diffusion_timesteps
53
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
54
+ return torch.tensor(betas, dtype=torch.float32)
55
+
56
+
57
+ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
58
+ """
59
+ Implements Algorithm 2 (Heun steps) from Karras et al. (2022) for discrete beta schedules.
60
+ Based on the original k-diffusion implementation by Katherine Crowson:
61
+ https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/
62
+ k_diffusion/sampling.py#L90
63
+
64
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed
65
+ in the scheduler's `__init__` function, such as `num_train_timesteps`.
66
+ They can be accessed via `scheduler.config.num_train_timesteps`.
67
+ [`SchedulerMixin`] provides general loading and saving functionality via the
68
+ [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions.
69
+
70
+ Args:
71
+ num_train_timesteps (`int`):
72
+ number of diffusion steps used to train the model.
73
+ beta_start (`float`):
74
+ the starting `beta` value of inference.
75
+ beta_end (`float`):
76
+ the final `beta` value.
77
+ beta_schedule (`str`):
78
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping
79
+ the model. Choose from `linear` or `scaled_linear`.
80
+ trained_betas (`np.ndarray`, optional):
81
+ option to pass an array of betas directly to the constructor to bypass
82
+ `beta_start`, `beta_end`, etc.
86
+ prediction_type (`str`, default `epsilon`, optional):
87
+ prediction type of the scheduler function, one of
88
+ `epsilon` (predicting the noise of the diffusion process),
89
+ `sample` (directly predicting the noisy sample), or
90
+ `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf)
91
+ """
92
+
93
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
94
+ order = 2
95
+
96
+ @register_to_config
97
+ def __init__(
98
+ self,
99
+ num_train_timesteps: int = 1000,
100
+ beta_start: float = 0.00085, # sensible defaults
101
+ beta_end: float = 0.012,
102
+ beta_schedule: str = "linear",
103
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
104
+ prediction_type: str = "epsilon",
105
+ use_karras_sigmas: Optional[bool] = False,
106
+ ):
107
+ if trained_betas is not None:
108
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
109
+ elif beta_schedule == "linear":
110
+ self.betas = torch.linspace(
111
+ beta_start, beta_end, num_train_timesteps, dtype=torch.float32
112
+ )
113
+ elif beta_schedule == "scaled_linear":
114
+ # this schedule is very specific to the latent diffusion model.
115
+ self.betas = (
116
+ torch.linspace(
117
+ beta_start ** 0.5, beta_end ** 0.5,
118
+ num_train_timesteps, dtype=torch.float32
119
+ ) ** 2
120
+ )
121
+ elif beta_schedule == "squaredcos_cap_v2":
122
+ # Glide cosine schedule
123
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
124
+ else:
125
+ raise NotImplementedError(
126
+ f"{beta_schedule} is not implemented for {self.__class__}"
127
+ )
128
+
129
+ self.alphas = 1.0 - self.betas
130
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
131
+
132
+ # set all values
133
+ self.use_karras_sigmas = use_karras_sigmas
134
+ self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
135
+
136
+ def index_for_timestep(self, timestep):
137
+ """Get the index into self.timesteps that matches `timestep`: the last matching
138
+ position in first-order state, or the one before it in second-order state."""
139
+ assert len(timestep.shape) < 2
140
+ avail_timesteps = self.timesteps.reshape(1, -1).to(timestep.device)
141
+ mask = (avail_timesteps == timestep.reshape(-1, 1))
142
+ assert (mask.sum(dim=1) != 0).all(), f"timestep: {timestep.tolist()}"
143
+ mask = mask.cpu() * torch.arange(mask.shape[1]).reshape(1, -1)
144
+
145
+ if self.state_in_first_order:
146
+ return mask.argmax(dim=1).numpy()
147
+ else:
148
+ return mask.argmax(dim=1).numpy() - 1
149
+
150
+ def scale_model_input(
151
+ self,
152
+ sample: torch.FloatTensor,
153
+ timestep: Union[float, torch.FloatTensor],
154
+ ) -> torch.FloatTensor:
155
+ """
156
+ Ensures interchangeability with schedulers that need to scale the
157
+ denoising model input depending on the current timestep.
158
+ Args:
159
+ sample (`torch.FloatTensor`): input sample
160
+ timestep (`int`, optional): current timestep
161
+ Returns:
162
+ `torch.FloatTensor`: scaled input sample
163
+ """
164
+ if not torch.is_tensor(timestep):
165
+ timestep = torch.tensor(timestep)
166
+ timestep = timestep.to(sample.device).reshape(-1)
167
+ step_index = self.index_for_timestep(timestep)
168
+
169
+ sigma = self.sigmas[step_index].reshape(-1, 1, 1, 1).to(sample.device)
170
+ sample = sample / ((sigma ** 2 + 1) ** 0.5) # sample *= sqrt_alpha_prod
171
+ return sample
172
+
173
+ def set_timesteps(
174
+ self,
175
+ num_inference_steps: int,
176
+ device: Union[str, torch.device] = None,
177
+ num_train_timesteps: Optional[int] = None,
178
+ ):
179
+ """
180
+ Sets the timesteps used for the diffusion chain.
181
+ Supporting function to be run before inference.
182
+
183
+ Args:
184
+ num_inference_steps (`int`):
185
+ the number of diffusion steps used when generating samples
186
+ with a pre-trained model.
187
+ device (`str` or `torch.device`, optional):
188
+ the device to which the timesteps should be moved to.
189
+ If `None`, the timesteps are not moved.
190
+ """
191
+ self.num_inference_steps = num_inference_steps
192
+ num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps
193
+
194
+ timesteps = np.linspace(
195
+ 0, num_train_timesteps - 1, num_inference_steps, dtype=float
196
+ )[::-1].copy()
197
+
198
+ # sigma = sqrt((1 - alpha_cumprod) / alpha_cumprod)
199
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
200
+ log_sigmas = np.log(sigmas)
201
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
202
+
203
+ if self.use_karras_sigmas:
204
+ sigmas = self._convert_to_karras(
205
+ in_sigmas=sigmas, num_inference_steps=self.num_inference_steps
206
+ )
207
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
208
+
209
+ sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
210
+ sigmas = torch.from_numpy(sigmas).to(device=device)
211
+ self.sigmas = torch.cat(
212
+ [sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]
213
+ )
214
+
215
+ # standard deviation of the initial noise distribution
216
+ self.init_noise_sigma = self.sigmas.max()
217
+
218
+ timesteps = torch.from_numpy(timesteps)
219
+ timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
220
+ if 'mps' in str(device):
221
+ timesteps = timesteps.float()
222
+ self.timesteps = timesteps.to(device)
223
+
224
+ # empty dt and derivative
225
+ self.prev_derivative = None
226
+ self.dt = None
227
+
228
+ def _sigma_to_t(self, sigma, log_sigmas):
229
+ # get log sigma
230
+ log_sigma = np.log(sigma)
231
+
232
+ # get distribution
233
+ dists = log_sigma - log_sigmas[:, np.newaxis]
234
+
235
+ # get sigmas range
236
+ low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(
237
+ max=log_sigmas.shape[0] - 2
238
+ )
239
+ high_idx = low_idx + 1
240
+
241
+ low = log_sigmas[low_idx]
242
+ high = log_sigmas[high_idx]
243
+
244
+ # interpolate sigmas
245
+ w = (low - log_sigma) / (low - high)
246
+ w = np.clip(w, 0, 1)
247
+
248
+ # transform interpolation to time range
249
+ t = (1 - w) * low_idx + w * high_idx
250
+ t = t.reshape(sigma.shape)
251
+ return t
252
+
253
+ def _convert_to_karras(
254
+ self, in_sigmas: torch.FloatTensor, num_inference_steps
255
+ ) -> torch.FloatTensor:
256
+ """Constructs the noise schedule of Karras et al. (2022)."""
257
+
258
+ sigma_min: float = in_sigmas[-1].item()
259
+ sigma_max: float = in_sigmas[0].item()
260
+
261
+ rho = 7.0 # 7.0 is the value used in the paper
262
+ ramp = np.linspace(0, 1, num_inference_steps)
263
+ min_inv_rho = sigma_min ** (1 / rho)
264
+ max_inv_rho = sigma_max ** (1 / rho)
265
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
266
+ return sigmas
267
+
268
+ @property
269
+ def state_in_first_order(self):
270
+ return self.dt is None
271
+
272
+ def step(
273
+ self,
274
+ model_output: Union[torch.FloatTensor, np.ndarray],
275
+ timestep: Union[float, torch.FloatTensor],
276
+ sample: Union[torch.FloatTensor, np.ndarray],
277
+ return_dict: bool = True,
278
+ ) -> Union[SchedulerOutput, Tuple]:
279
+ """
280
+ Predict the sample at the previous timestep by reversing the SDE.
281
+ Core function to propagate the diffusion process from the learned
282
+ model outputs (most often the predicted noise).
283
+ Args:
284
+ model_output (`torch.FloatTensor` or `np.ndarray`):
285
+ direct output from learned diffusion model.
286
+ timestep (`int`):
287
+ current discrete timestep in the diffusion chain.
288
+ sample (`torch.FloatTensor` or `np.ndarray`):
289
+ current instance of sample being created by diffusion process.
290
+ return_dict (`bool`):
291
+ option for returning tuple rather than SchedulerOutput class
292
+ Returns:
293
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
294
+ [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict`
295
+ is True, otherwise a `tuple`. When returning a tuple,
296
+ the first element is the sample tensor.
297
+ """
298
+ if not torch.is_tensor(timestep):
299
+ timestep = torch.tensor(timestep)
300
+ timestep = timestep.reshape(-1).to(sample.device)
301
+ step_index = self.index_for_timestep(timestep)
302
+
303
+ if self.state_in_first_order:
304
+ sigma = self.sigmas[step_index]
305
+ sigma_next = self.sigmas[step_index + 1]
306
+ else:
307
+ # 2nd order / Heun's method
308
+ sigma = self.sigmas[step_index - 1]
309
+ sigma_next = self.sigmas[step_index]
310
+
311
+ sigma = sigma.reshape(-1, 1, 1, 1).to(sample.device)
312
+ sigma_next = sigma_next.reshape(-1, 1, 1, 1).to(sample.device)
313
+ sigma_input = sigma if self.state_in_first_order else sigma_next
314
+
315
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
316
+ if self.config.prediction_type == "epsilon":
317
+ pred_original_sample = sample - sigma_input * model_output
318
+ elif self.config.prediction_type == "v_prediction":
319
+ alpha_prod = 1 / (sigma_input ** 2 + 1)
320
+ pred_original_sample = (
321
+ sample * alpha_prod - model_output * (sigma_input * alpha_prod ** .5)
322
+ )
323
+ elif self.config.prediction_type == "sample":
324
+ raise NotImplementedError("prediction_type not implemented yet: sample")
325
+ else:
326
+ raise ValueError(
327
+ f"prediction_type given as {self.config.prediction_type} "
328
+ "must be one of `epsilon` or `v_prediction`"
329
+ )
330
+
331
+ if self.state_in_first_order:
332
+ # 2. Convert to an ODE derivative for 1st order
333
+ derivative = (sample - pred_original_sample) / sigma
334
+ # 3. delta timestep
335
+ dt = sigma_next - sigma
336
+
337
+ # store for 2nd order step
338
+ self.prev_derivative = derivative
339
+ self.dt = dt
340
+ self.sample = sample
341
+ else:
342
+ # 2. 2nd order / Heun's method
343
+ derivative = (sample - pred_original_sample) / sigma_next
344
+ derivative = (self.prev_derivative + derivative) / 2
345
+
346
+ # 3. take prev timestep & sample
347
+ dt = self.dt
348
+ sample = self.sample
349
+
350
+ # free dt and derivative
351
+ # Note, this puts the scheduler in "first order mode"
352
+ self.prev_derivative = None
353
+ self.dt = None
354
+ self.sample = None
355
+
356
+ prev_sample = sample + derivative * dt
357
+
358
+ if not return_dict:
359
+ return (prev_sample,)
360
+
361
+ return SchedulerOutput(prev_sample=prev_sample)
362
+
363
+ def add_noise(
364
+ self,
365
+ original_samples: torch.FloatTensor,
366
+ noise: torch.FloatTensor,
367
+ timesteps: torch.FloatTensor,
368
+ ) -> torch.FloatTensor:
369
+
370
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
371
+ self.sigmas = self.sigmas.to(
372
+ device=original_samples.device, dtype=original_samples.dtype
373
+ )
374
+ self.timesteps = self.timesteps.to(original_samples.device)
375
+ timesteps = timesteps.to(original_samples.device)
376
+
377
+ step_indices = self.index_for_timestep(timesteps)
378
+
379
+ sigma = self.sigmas[step_indices].flatten()
380
+ while len(sigma.shape) < len(original_samples.shape):
381
+ sigma = sigma.unsqueeze(-1)
382
+
383
+ noisy_samples = original_samples + noise * sigma
384
+ return noisy_samples
385
+
386
+ def __len__(self):
387
+ return self.config.num_train_timesteps
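The `step` method above alternates between a first-order Euler update and a second-order Heun correction, caching `prev_derivative`, `dt`, and `sample` in between. A minimal sampling-loop sketch follows (not part of the commit); it assumes the class defined in this file keeps the upstream `HeunDiscreteScheduler` name and the usual diffusers interface (`set_timesteps`, `scale_model_input`, `init_noise_sigma`), and the `denoiser` placeholder stands in for the text-conditioned UNet that `consistencytta.py` would normally call.

```python
import torch

# Assumed class name and interface, matching upstream diffusers; adjust if this vendored copy differs.
from diffusers.scheduling_heun_discrete import HeunDiscreteScheduler

scheduler = HeunDiscreteScheduler(num_train_timesteps=1000, prediction_type="v_prediction")
scheduler.set_timesteps(num_inference_steps=18, device="cpu")

def denoiser(x, t):
    # Placeholder for the guided UNet; returns a dummy v-prediction with the right shape.
    return torch.zeros_like(x)

sample = torch.randn(1, 8, 128, 16) * scheduler.init_noise_sigma  # illustrative latent shape

for t in scheduler.timesteps:
    model_input = scheduler.scale_model_input(sample, t)
    model_output = denoiser(model_input, t)
    # Each step() call is either the Euler predictor or the Heun corrector,
    # depending on state_in_first_order; prev_sample is the updated latent.
    sample = scheduler.step(model_output, t, sample).prev_sample
```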
diffusers/utils/configuration_utils.py ADDED
@@ -0,0 +1,647 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ ConfigMixin base class and utilities."""
17
+
18
+ import dataclasses
19
+ import functools
20
+ import importlib
21
+ import inspect
22
+ import json
23
+ import os
24
+ import re
25
+ from collections import OrderedDict
26
+ from pathlib import PosixPath
27
+ from typing import Any, Dict, Tuple, Union
28
+
29
+ import numpy as np
30
+ from huggingface_hub import hf_hub_download
31
+ from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
32
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
33
+ from requests import HTTPError
34
+
35
+ from .import_utils import DummyObject
36
+ from .deprecation_utils import deprecate
37
+ from .hub_utils import extract_commit_hash, http_user_agent
38
+ from .logging import get_logger
39
+ from .constants import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT
40
+
41
+
42
+ logger = get_logger(__name__)
43
+
44
+ _re_configuration_file = re.compile(r"config\.(.*)\.json")
45
+
46
+
47
+ class FrozenDict(OrderedDict):
48
+ def __init__(self, *args, **kwargs):
49
+ super().__init__(*args, **kwargs)
50
+
51
+ for key, value in self.items():
52
+ setattr(self, key, value)
53
+
54
+ self.__frozen = True
55
+
56
+ def __delitem__(self, *args, **kwargs):
57
+ raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
58
+
59
+ def setdefault(self, *args, **kwargs):
60
+ raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
61
+
62
+ def pop(self, *args, **kwargs):
63
+ raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
64
+
65
+ def update(self, *args, **kwargs):
66
+ raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
67
+
68
+ def __setattr__(self, name, value):
69
+ if hasattr(self, "__frozen") and self.__frozen:
70
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
71
+ super().__setattr__(name, value)
72
+
73
+ def __setitem__(self, name, value):
74
+ if hasattr(self, "__frozen") and self.__frozen:
75
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
76
+ super().__setitem__(name, value)
77
+
78
+
79
+ class ConfigMixin:
80
+ r"""
81
+ Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also
82
+ provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and
83
+ saving classes that inherit from [`ConfigMixin`].
84
+
85
+ Class attributes:
86
+ - **config_name** (`str`) -- A filename under which the config should stored when calling
87
+ [`~ConfigMixin.save_config`] (should be overridden by parent class).
88
+ - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
89
+ overridden by subclass).
90
+ - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
91
+ - **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
92
+ should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
93
+ subclass).
94
+ """
95
+ config_name = None
96
+ ignore_for_config = []
97
+ has_compatibles = False
98
+
99
+ _deprecated_kwargs = []
100
+
101
+ def register_to_config(self, **kwargs):
102
+ if self.config_name is None:
103
+ raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
104
+ # Special case for `kwargs` used in deprecation warning added to schedulers
105
+ # TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
106
+ # or solve in a more general way.
107
+ kwargs.pop("kwargs", None)
108
+
109
+ if not hasattr(self, "_internal_dict"):
110
+ internal_dict = kwargs
111
+ else:
112
+ previous_dict = dict(self._internal_dict)
113
+ internal_dict = {**self._internal_dict, **kwargs}
114
+ logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
115
+
116
+ self._internal_dict = FrozenDict(internal_dict)
117
+
118
+ def __getattr__(self, name: str) -> Any:
119
+ """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
120
+ config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
121
+
122
+ This function is mostly copied from PyTorch's __getattr__ overwrite:
123
+ https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
124
+ """
125
+
126
+ is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
127
+ is_attribute = name in self.__dict__
128
+
129
+ if is_in_config and not is_attribute:
130
+ deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
131
+ deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
132
+ return self._internal_dict[name]
133
+
134
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
135
+
136
+ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
137
+ """
138
+ Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
139
+ [`~ConfigMixin.from_config`] class method.
140
+
141
+ Args:
142
+ save_directory (`str` or `os.PathLike`):
143
+ Directory where the configuration JSON file is saved (will be created if it does not exist).
144
+ """
145
+ if os.path.isfile(save_directory):
146
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
147
+
148
+ os.makedirs(save_directory, exist_ok=True)
149
+
150
+ # If we save using the predefined names, we can load using `from_config`
151
+ output_config_file = os.path.join(save_directory, self.config_name)
152
+
153
+ self.to_json_file(output_config_file)
154
+ logger.info(f"Configuration saved in {output_config_file}")
155
+
156
+ @classmethod
157
+ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
158
+ r"""
159
+ Instantiate a Python class from a config dictionary.
160
+
161
+ Parameters:
162
+ config (`Dict[str, Any]`):
163
+ A config dictionary from which the Python class is instantiated. Make sure to only load configuration
164
+ files of compatible classes.
165
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
166
+ Whether kwargs that are not consumed by the Python class should be returned or not.
167
+ kwargs (remaining dictionary of keyword arguments, *optional*):
168
+ Can be used to update the configuration object (after it is loaded) and initiate the Python class.
169
+ `**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually
170
+ overwrite the same named arguments in `config`.
171
+
172
+ Returns:
173
+ [`ModelMixin`] or [`SchedulerMixin`]:
174
+ A model or scheduler object instantiated from a config dictionary.
175
+
176
+ Examples:
177
+
178
+ ```python
179
+ >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
180
+
181
+ >>> # Download scheduler from huggingface.co and cache.
182
+ >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
183
+
184
+ >>> # Instantiate DDIM scheduler class with same config as DDPM
185
+ >>> scheduler = DDIMScheduler.from_config(scheduler.config)
186
+
187
+ >>> # Instantiate PNDM scheduler class with same config as DDPM
188
+ >>> scheduler = PNDMScheduler.from_config(scheduler.config)
189
+ ```
190
+ """
191
+ # <===== TO BE REMOVED WITH DEPRECATION
192
+ # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
193
+ if "pretrained_model_name_or_path" in kwargs:
194
+ config = kwargs.pop("pretrained_model_name_or_path")
195
+
196
+ if config is None:
197
+ raise ValueError("Please make sure to provide a config as the first positional argument.")
198
+ # ======>
199
+
200
+ if not isinstance(config, dict):
201
+ deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
202
+ if "Scheduler" in cls.__name__:
203
+ deprecation_message += (
204
+ f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
205
+ " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
206
+ " be removed in v1.0.0."
207
+ )
208
+ elif "Model" in cls.__name__:
209
+ deprecation_message += (
210
+ f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
211
+ f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
212
+ " instead. This functionality will be removed in v1.0.0."
213
+ )
214
+ deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
215
+ config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
216
+
217
+ init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
218
+
219
+ # Allow dtype to be specified on initialization
220
+ if "dtype" in unused_kwargs:
221
+ init_dict["dtype"] = unused_kwargs.pop("dtype")
222
+
223
+ # add possible deprecated kwargs
224
+ for deprecated_kwarg in cls._deprecated_kwargs:
225
+ if deprecated_kwarg in unused_kwargs:
226
+ init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
227
+
228
+ # Return model and optionally state and/or unused_kwargs
229
+ model = cls(**init_dict)
230
+
231
+ # make sure to also save config parameters that might be used for compatible classes
232
+ model.register_to_config(**hidden_dict)
233
+
234
+ # add hidden kwargs of compatible classes to unused_kwargs
235
+ unused_kwargs = {**unused_kwargs, **hidden_dict}
236
+
237
+ if return_unused_kwargs:
238
+ return (model, unused_kwargs)
239
+ else:
240
+ return model
241
+
242
+ @classmethod
243
+ def get_config_dict(cls, *args, **kwargs):
244
+ deprecation_message = (
245
+ f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
246
+ " removed in version v1.0.0"
247
+ )
248
+ deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
249
+ return cls.load_config(*args, **kwargs)
250
+
251
+ @classmethod
252
+ def load_config(
253
+ cls,
254
+ pretrained_model_name_or_path: Union[str, os.PathLike],
255
+ return_unused_kwargs=False,
256
+ return_commit_hash=False,
257
+ **kwargs,
258
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
259
+ r"""
260
+ Load a model or scheduler configuration.
261
+
262
+ Parameters:
263
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
264
+ Can be either:
265
+
266
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
267
+ the Hub.
268
+ - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with
269
+ [`~ConfigMixin.save_config`].
270
+
271
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
272
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
273
+ is not used.
274
+ force_download (`bool`, *optional*, defaults to `False`):
275
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
276
+ cached versions if they exist.
277
+ resume_download (`bool`, *optional*, defaults to `False`):
278
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
279
+ incompletely downloaded files are deleted.
280
+ proxies (`Dict[str, str]`, *optional*):
281
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
282
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
283
+ output_loading_info(`bool`, *optional*, defaults to `False`):
284
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
285
+ local_files_only (`bool`, *optional*, defaults to `False`):
286
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
287
+ won't be downloaded from the Hub.
288
+ use_auth_token (`str` or *bool*, *optional*):
289
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
290
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
291
+ revision (`str`, *optional*, defaults to `"main"`):
292
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
293
+ allowed by Git.
294
+ subfolder (`str`, *optional*, defaults to `""`):
295
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
296
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
297
+ Whether unused keyword arguments of the config are returned.
298
+ return_commit_hash (`bool`, *optional*, defaults to `False`):
299
+ Whether the `commit_hash` of the loaded configuration is returned.
300
+
301
+ Returns:
302
+ `dict`:
303
+ A dictionary of all the parameters stored in a JSON configuration file.
304
+
305
+ """
306
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
307
+ force_download = kwargs.pop("force_download", False)
308
+ resume_download = kwargs.pop("resume_download", False)
309
+ proxies = kwargs.pop("proxies", None)
310
+ use_auth_token = kwargs.pop("use_auth_token", None)
311
+ local_files_only = kwargs.pop("local_files_only", False)
312
+ revision = kwargs.pop("revision", None)
313
+ _ = kwargs.pop("mirror", None)
314
+ subfolder = kwargs.pop("subfolder", None)
315
+ user_agent = kwargs.pop("user_agent", {})
316
+
317
+ user_agent = {**user_agent, "file_type": "config"}
318
+ user_agent = http_user_agent(user_agent)
319
+
320
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
321
+
322
+ if cls.config_name is None:
323
+ raise ValueError(
324
+ "`self.config_name` is not defined. Note that one should not load a config from "
325
+ "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
326
+ )
327
+
328
+ if os.path.isfile(pretrained_model_name_or_path):
329
+ config_file = pretrained_model_name_or_path
330
+ elif os.path.isdir(pretrained_model_name_or_path):
331
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
332
+ # Load from a PyTorch checkpoint
333
+ config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
334
+ elif subfolder is not None and os.path.isfile(
335
+ os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
336
+ ):
337
+ config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
338
+ else:
339
+ raise EnvironmentError(
340
+ f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
341
+ )
342
+ else:
343
+ try:
344
+ # Load from URL or cache if already cached
345
+ config_file = hf_hub_download(
346
+ pretrained_model_name_or_path,
347
+ filename=cls.config_name,
348
+ cache_dir=cache_dir,
349
+ force_download=force_download,
350
+ proxies=proxies,
351
+ resume_download=resume_download,
352
+ local_files_only=local_files_only,
353
+ use_auth_token=use_auth_token,
354
+ user_agent=user_agent,
355
+ subfolder=subfolder,
356
+ revision=revision,
357
+ )
358
+ except RepositoryNotFoundError:
359
+ raise EnvironmentError(
360
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
361
+ " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
362
+ " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
363
+ " login`."
364
+ )
365
+ except RevisionNotFoundError:
366
+ raise EnvironmentError(
367
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
368
+ " this model name. Check the model page at"
369
+ f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
370
+ )
371
+ except EntryNotFoundError:
372
+ raise EnvironmentError(
373
+ f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
374
+ )
375
+ except HTTPError as err:
376
+ raise EnvironmentError(
377
+ "There was a specific connection error when trying to load"
378
+ f" {pretrained_model_name_or_path}:\n{err}"
379
+ )
380
+ except ValueError:
381
+ raise EnvironmentError(
382
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
383
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
384
+ f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
385
+ " run the library in offline mode at"
386
+ " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
387
+ )
388
+ except EnvironmentError:
389
+ raise EnvironmentError(
390
+ f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
391
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
392
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
393
+ f"containing a {cls.config_name} file"
394
+ )
395
+
396
+ try:
397
+ # Load config dict
398
+ config_dict = cls._dict_from_json_file(config_file)
399
+
400
+ commit_hash = extract_commit_hash(config_file)
401
+ except (json.JSONDecodeError, UnicodeDecodeError):
402
+ raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
403
+
404
+ if not (return_unused_kwargs or return_commit_hash):
405
+ return config_dict
406
+
407
+ outputs = (config_dict,)
408
+
409
+ if return_unused_kwargs:
410
+ outputs += (kwargs,)
411
+
412
+ if return_commit_hash:
413
+ outputs += (commit_hash,)
414
+
415
+ return outputs
416
+
417
+ @staticmethod
418
+ def _get_init_keys(cls):
419
+ return set(dict(inspect.signature(cls.__init__).parameters).keys())
420
+
421
+ @classmethod
422
+ def extract_init_dict(cls, config_dict, **kwargs):
423
+ # 0. Copy origin config dict
424
+ original_dict = dict(config_dict.items())
425
+
426
+ # 1. Retrieve expected config attributes from __init__ signature
427
+ expected_keys = cls._get_init_keys(cls)
428
+ expected_keys.remove("self")
429
+ # remove general kwargs if present in dict
430
+ if "kwargs" in expected_keys:
431
+ expected_keys.remove("kwargs")
432
+ # remove flax internal keys
433
+ if hasattr(cls, "_flax_internal_args"):
434
+ for arg in cls._flax_internal_args:
435
+ expected_keys.remove(arg)
436
+
437
+ # 2. Remove attributes that cannot be expected from expected config attributes
438
+ # remove keys to be ignored
439
+ if len(cls.ignore_for_config) > 0:
440
+ expected_keys = expected_keys - set(cls.ignore_for_config)
441
+
442
+ # load diffusers library to import compatible and original scheduler
443
+ diffusers_library = importlib.import_module(__name__.split(".")[0])
444
+
445
+ if cls.has_compatibles:
446
+ compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
447
+ else:
448
+ compatible_classes = []
449
+
450
+ expected_keys_comp_cls = set()
451
+ for c in compatible_classes:
452
+ expected_keys_c = cls._get_init_keys(c)
453
+ expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
454
+ expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
455
+ config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
456
+
457
+ # remove attributes from orig class that cannot be expected
458
+ orig_cls_name = config_dict.pop("_class_name", cls.__name__)
459
+ if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
460
+ orig_cls = getattr(diffusers_library, orig_cls_name)
461
+ unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
462
+ config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
463
+
464
+ # remove private attributes
465
+ config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
466
+
467
+ # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
468
+ init_dict = {}
469
+ for key in expected_keys:
470
+ # if config param is passed to kwarg and is present in config dict
471
+ # it should overwrite existing config dict key
472
+ if key in kwargs and key in config_dict:
473
+ config_dict[key] = kwargs.pop(key)
474
+
475
+ if key in kwargs:
476
+ # overwrite key
477
+ init_dict[key] = kwargs.pop(key)
478
+ elif key in config_dict:
479
+ # use value from config dict
480
+ init_dict[key] = config_dict.pop(key)
481
+
482
+ # 4. Give nice warning if unexpected values have been passed
483
+ if len(config_dict) > 0:
484
+ logger.warning(
485
+ f"The config attributes {config_dict} were passed to {cls.__name__}, "
486
+ "but are not expected and will be ignored. Please verify your "
487
+ f"{cls.config_name} configuration file."
488
+ )
489
+
490
+ # 5. Give nice info if config attributes are initialized to default because they have not been passed
491
+ passed_keys = set(init_dict.keys())
492
+ if len(expected_keys - passed_keys) > 0:
493
+ logger.info(
494
+ f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
495
+ )
496
+
497
+ # 6. Define unused keyword arguments
498
+ unused_kwargs = {**config_dict, **kwargs}
499
+
500
+ # 7. Define "hidden" config parameters that were saved for compatible classes
501
+ hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
502
+
503
+ return init_dict, unused_kwargs, hidden_config_dict
504
+
505
+ @classmethod
506
+ def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
507
+ with open(json_file, "r", encoding="utf-8") as reader:
508
+ text = reader.read()
509
+ return json.loads(text)
510
+
511
+ def __repr__(self):
512
+ return f"{self.__class__.__name__} {self.to_json_string()}"
513
+
514
+ @property
515
+ def config(self) -> Dict[str, Any]:
516
+ """
517
+ Returns the config of the class as a frozen dictionary
518
+
519
+ Returns:
520
+ `Dict[str, Any]`: Config of the class.
521
+ """
522
+ return self._internal_dict
523
+
524
+ def to_json_string(self) -> str:
525
+ """
526
+ Serializes the configuration instance to a JSON string.
527
+
528
+ Returns:
529
+ `str`:
530
+ String containing all the attributes that make up the configuration instance in JSON format.
531
+ """
532
+ config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
533
+ config_dict["_class_name"] = self.__class__.__name__
534
+ config_dict["_diffusers_version"] = __version__
535
+
536
+ def to_json_saveable(value):
537
+ if isinstance(value, np.ndarray):
538
+ value = value.tolist()
539
+ elif isinstance(value, PosixPath):
540
+ value = str(value)
541
+ return value
542
+
543
+ config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
544
+ # Don't save "_ignore_files"
545
+ config_dict.pop("_ignore_files", None)
546
+
547
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
548
+
549
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
550
+ """
551
+ Save the configuration instance's parameters to a JSON file.
552
+
553
+ Args:
554
+ json_file_path (`str` or `os.PathLike`):
555
+ Path to the JSON file to save a configuration instance's parameters.
556
+ """
557
+ with open(json_file_path, "w", encoding="utf-8") as writer:
558
+ writer.write(self.to_json_string())
559
+
560
+
561
+ def register_to_config(init):
562
+ r"""
563
+ Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
564
+ automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
565
+ shouldn't be registered in the config, use the `ignore_for_config` class variable
566
+
567
+ Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
568
+ """
569
+
570
+ @functools.wraps(init)
571
+ def inner_init(self, *args, **kwargs):
572
+ # Ignore private kwargs in the init.
573
+ init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
574
+ config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
575
+ if not isinstance(self, ConfigMixin):
576
+ raise RuntimeError(
577
+ f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
578
+ "not inherit from `ConfigMixin`."
579
+ )
580
+
581
+ ignore = getattr(self, "ignore_for_config", [])
582
+ # Get positional arguments aligned with kwargs
583
+ new_kwargs = {}
584
+ signature = inspect.signature(init)
585
+ parameters = {
586
+ name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
587
+ }
588
+ for arg, name in zip(args, parameters.keys()):
589
+ new_kwargs[name] = arg
590
+
591
+ # Then add all kwargs
592
+ new_kwargs.update(
593
+ {
594
+ k: init_kwargs.get(k, default)
595
+ for k, default in parameters.items()
596
+ if k not in ignore and k not in new_kwargs
597
+ }
598
+ )
599
+ new_kwargs = {**config_init_kwargs, **new_kwargs}
600
+ getattr(self, "register_to_config")(**new_kwargs)
601
+ init(self, *args, **init_kwargs)
602
+
603
+ return inner_init
604
+
605
+
606
+ def flax_register_to_config(cls):
607
+ original_init = cls.__init__
608
+
609
+ @functools.wraps(original_init)
610
+ def init(self, *args, **kwargs):
611
+ if not isinstance(self, ConfigMixin):
612
+ raise RuntimeError(
613
+ f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
614
+ "not inherit from `ConfigMixin`."
615
+ )
616
+
617
+ # Ignore private kwargs in the init. Retrieve all passed attributes
618
+ init_kwargs = dict(kwargs.items())
619
+
620
+ # Retrieve default values
621
+ fields = dataclasses.fields(self)
622
+ default_kwargs = {}
623
+ for field in fields:
624
+ # ignore flax specific attributes
625
+ if field.name in self._flax_internal_args:
626
+ continue
627
+ if type(field.default) == dataclasses._MISSING_TYPE:
628
+ default_kwargs[field.name] = None
629
+ else:
630
+ default_kwargs[field.name] = getattr(self, field.name)
631
+
632
+ # Make sure init_kwargs override default kwargs
633
+ new_kwargs = {**default_kwargs, **init_kwargs}
634
+ # dtype should be part of `init_kwargs`, but not `new_kwargs`
635
+ if "dtype" in new_kwargs:
636
+ new_kwargs.pop("dtype")
637
+
638
+ # Get positional arguments aligned with kwargs
639
+ for i, arg in enumerate(args):
640
+ name = fields[i].name
641
+ new_kwargs[name] = arg
642
+
643
+ getattr(self, "register_to_config")(**new_kwargs)
644
+ original_init(self, *args, **kwargs)
645
+
646
+ cls.__init__ = init
647
+ return cls
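For context on how the pieces above fit together, here is a small illustrative sketch (not part of the commit) of a class wired up with `ConfigMixin` and the `@register_to_config` decorator; the `ToyScheduler` name is made up for the example.

```python
from diffusers.utils.configuration_utils import ConfigMixin, register_to_config

class ToyScheduler(ConfigMixin):
    # Filename used by save_config(); register_to_config() requires it to be set.
    config_name = "scheduler_config.json"

    @register_to_config
    def __init__(self, num_train_timesteps: int = 1000, prediction_type: str = "epsilon"):
        # The decorator records all non-private __init__ arguments in self.config (a FrozenDict).
        self.num_train_timesteps = num_train_timesteps
        self.prediction_type = prediction_type

sched = ToyScheduler(prediction_type="v_prediction")
print(sched.config.prediction_type)        # "v_prediction"

# from_config() rebuilds an equivalent instance from a plain config dictionary.
clone = ToyScheduler.from_config(dict(sched.config))
print(clone.config.num_train_timesteps)    # 1000
```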
diffusers/utils/constants.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+
17
+ from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE, hf_cache_home
18
+
19
+
20
+ default_cache_path = HUGGINGFACE_HUB_CACHE
21
+
22
+
23
+ CONFIG_NAME = "config.json"
24
+ WEIGHTS_NAME = "diffusion_pytorch_model.bin"
25
+ FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
26
+ ONNX_WEIGHTS_NAME = "model.onnx"
27
+ SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
28
+ ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
29
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
30
+ DIFFUSERS_CACHE = default_cache_path
31
+ DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
32
+ HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
33
+ DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
34
+ TEXT_ENCODER_ATTN_MODULE = ".self_attn"
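These constants are mostly consumed as defaults elsewhere in the vendored package; for instance, `ConfigMixin.load_config` above falls back to `DIFFUSERS_CACHE` when no `cache_dir` is passed. A trivial, purely illustrative check:

```python
from diffusers.utils.constants import CONFIG_NAME, WEIGHTS_NAME, DIFFUSERS_CACHE

print(CONFIG_NAME)      # "config.json"
print(WEIGHTS_NAME)     # "diffusion_pytorch_model.bin"
print(DIFFUSERS_CACHE)  # the local huggingface_hub cache directory on this machine
```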
diffusers/utils/deprecation_utils.py ADDED
@@ -0,0 +1,49 @@
1
+ import inspect
2
+ import warnings
3
+ from typing import Any, Dict, Optional, Union
4
+
5
+ from packaging import version
6
+
7
+
8
+ def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2):
9
+ from .. import __version__
10
+
11
+ deprecated_kwargs = take_from
12
+ values = ()
13
+ if not isinstance(args[0], tuple):
14
+ args = (args,)
15
+
16
+ for attribute, version_name, message in args:
17
+ if version.parse(version.parse(__version__).base_version) >= version.parse(version_name):
18
+ raise ValueError(
19
+ f"The deprecation tuple {(attribute, version_name, message)} should be removed since diffusers'"
20
+ f" version {__version__} is >= {version_name}"
21
+ )
22
+
23
+ warning = None
24
+ if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs:
25
+ values += (deprecated_kwargs.pop(attribute),)
26
+ warning = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}."
27
+ elif hasattr(deprecated_kwargs, attribute):
28
+ values += (getattr(deprecated_kwargs, attribute),)
29
+ warning = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}."
30
+ elif deprecated_kwargs is None:
31
+ warning = f"`{attribute}` is deprecated and will be removed in version {version_name}."
32
+
33
+ if warning is not None:
34
+ warning = warning + " " if standard_warn else ""
35
+ warnings.warn(warning + message, FutureWarning, stacklevel=stacklevel)
36
+
37
+ if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0:
38
+ call_frame = inspect.getouterframes(inspect.currentframe())[1]
39
+ filename = call_frame.filename
40
+ line_number = call_frame.lineno
41
+ function = call_frame.function
42
+ key, value = next(iter(deprecated_kwargs.items()))
43
+ raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`")
44
+
45
+ if len(values) == 0:
46
+ return
47
+ elif len(values) == 1:
48
+ return values[0]
49
+ return values
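A usage sketch (not part of the commit) for the `deprecate()` helper above: it pops a renamed keyword argument out of `kwargs` and emits a `FutureWarning`. Note that `deprecate()` lazily imports `__version__` from the top-level package, so this assumes the vendored `diffusers` exposes one; the `cfg_scale`/`guidance` names and the `1.0.0` removal version are purely illustrative.

```python
from diffusers.utils.deprecation_utils import deprecate

def generate(prompt, guidance=3.0, **kwargs):
    # Accept the old `cfg_scale` keyword for a while, warning callers to switch.
    cfg_scale = deprecate(
        "cfg_scale", "1.0.0", "Use `guidance` instead of `cfg_scale`.", take_from=kwargs
    )
    if cfg_scale is not None:
        guidance = cfg_scale
    return prompt, guidance

print(generate("a dog barking", cfg_scale=4.0))  # emits a FutureWarning, returns ('a dog barking', 4.0)
```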
diffusers/utils/hub_utils.py ADDED
@@ -0,0 +1,357 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import os
18
+ import re
19
+ import sys
20
+ import traceback
21
+ import warnings
22
+ from pathlib import Path
23
+ from typing import Dict, Optional, Union
24
+ from uuid import uuid4
25
+
26
+ from huggingface_hub import HfFolder, ModelCard, ModelCardData, hf_hub_download, whoami
27
+ from huggingface_hub.file_download import REGEX_COMMIT_HASH
28
+ from huggingface_hub.utils import (
29
+ EntryNotFoundError,
30
+ RepositoryNotFoundError,
31
+ RevisionNotFoundError,
32
+ is_jinja_available,
33
+ )
34
+ from packaging import version
35
+ from requests import HTTPError
36
+
37
+ from .constants import (
38
+ DEPRECATED_REVISION_ARGS,
39
+ DIFFUSERS_CACHE,
40
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT,
41
+ SAFETENSORS_WEIGHTS_NAME,
42
+ WEIGHTS_NAME,
43
+ )
44
+ from .import_utils import (
45
+ ENV_VARS_TRUE_VALUES,
46
+ _flax_version,
47
+ _jax_version,
48
+ _onnxruntime_version,
49
+ _torch_version,
50
+ is_flax_available,
51
+ is_onnx_available,
52
+ is_torch_available,
53
+ )
54
+ from .logging import get_logger
55
+
56
+
57
+ logger = get_logger(__name__)
58
+
59
+
60
+ MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "model_card_template.md"
61
+ SESSION_ID = uuid4().hex
62
+ HF_HUB_OFFLINE = os.getenv("HF_HUB_OFFLINE", "").upper() in ENV_VARS_TRUE_VALUES
63
+ DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES
64
+ HUGGINGFACE_CO_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/"
65
+
66
+
67
+ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
68
+ """
69
+ Formats a user-agent string with basic info about a request.
70
+ """
71
+ ua = f"diffusers; python/{sys.version.split()[0]}; session_id/{SESSION_ID}"
72
+ if DISABLE_TELEMETRY or HF_HUB_OFFLINE:
73
+ return ua + "; telemetry/off"
74
+ if is_torch_available():
75
+ ua += f"; torch/{_torch_version}"
76
+ if is_flax_available():
77
+ ua += f"; jax/{_jax_version}"
78
+ ua += f"; flax/{_flax_version}"
79
+ if is_onnx_available():
80
+ ua += f"; onnxruntime/{_onnxruntime_version}"
81
+ # CI will set this value to True
82
+ if os.environ.get("DIFFUSERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
83
+ ua += "; is_ci/true"
84
+ if isinstance(user_agent, dict):
85
+ ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items())
86
+ elif isinstance(user_agent, str):
87
+ ua += "; " + user_agent
88
+ return ua
89
+
90
+
91
+ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
92
+ if token is None:
93
+ token = HfFolder.get_token()
94
+ if organization is None:
95
+ username = whoami(token)["name"]
96
+ return f"{username}/{model_id}"
97
+ else:
98
+ return f"{organization}/{model_id}"
99
+
100
+
101
+ def create_model_card(args, model_name):
102
+ if not is_jinja_available():
103
+ raise ValueError(
104
+ "Modelcard rendering is based on Jinja templates."
105
+ " Please make sure to have `jinja` installed before using `create_model_card`."
106
+ " To install it, please run `pip install Jinja2`."
107
+ )
108
+
109
+ if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
110
+ return
111
+
112
+ hub_token = args.hub_token if hasattr(args, "hub_token") else None
113
+ repo_name = get_full_repo_name(model_name, token=hub_token)
114
+
115
+ model_card = ModelCard.from_template(
116
+ card_data=ModelCardData( # Card metadata object that will be converted to YAML block
117
+ language="en",
118
+ license="apache-2.0",
119
+ library_name="diffusers",
120
+ tags=[],
121
+ datasets=args.dataset_name,
122
+ metrics=[],
123
+ ),
124
+ template_path=MODEL_CARD_TEMPLATE_PATH,
125
+ model_name=model_name,
126
+ repo_name=repo_name,
127
+ dataset_name=args.dataset_name if hasattr(args, "dataset_name") else None,
128
+ learning_rate=args.learning_rate,
129
+ train_batch_size=args.train_batch_size,
130
+ eval_batch_size=args.eval_batch_size,
131
+ gradient_accumulation_steps=(
132
+ args.gradient_accumulation_steps if hasattr(args, "gradient_accumulation_steps") else None
133
+ ),
134
+ adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None,
135
+ adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None,
136
+ adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None,
137
+ adam_epsilon=args.adam_epsilon if hasattr(args, "adam_epsilon") else None,
138
+ lr_scheduler=args.lr_scheduler if hasattr(args, "lr_scheduler") else None,
139
+ lr_warmup_steps=args.lr_warmup_steps if hasattr(args, "lr_warmup_steps") else None,
140
+ ema_inv_gamma=args.ema_inv_gamma if hasattr(args, "ema_inv_gamma") else None,
141
+ ema_power=args.ema_power if hasattr(args, "ema_power") else None,
142
+ ema_max_decay=args.ema_max_decay if hasattr(args, "ema_max_decay") else None,
143
+ mixed_precision=args.mixed_precision,
144
+ )
145
+
146
+ card_path = os.path.join(args.output_dir, "README.md")
147
+ model_card.save(card_path)
148
+
149
+
150
+ def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str] = None):
151
+ """
152
+ Extracts the commit hash from a resolved filename toward a cache file.
153
+ """
154
+ if resolved_file is None or commit_hash is not None:
155
+ return commit_hash
156
+ resolved_file = str(Path(resolved_file).as_posix())
157
+ search = re.search(r"snapshots/([^/]+)/", resolved_file)
158
+ if search is None:
159
+ return None
160
+ commit_hash = search.groups()[0]
161
+ return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None
162
+
163
+
164
+ # Old default cache path, potentially to be migrated.
165
+ # This logic was more or less taken from `transformers`, with the following differences:
166
+ # - Diffusers doesn't use custom environment variables to specify the cache path.
167
+ # - There is no need to migrate the cache format, just move the files to the new location.
168
+ hf_cache_home = os.path.expanduser(
169
+ os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
170
+ )
171
+ old_diffusers_cache = os.path.join(hf_cache_home, "diffusers")
172
+
173
+
174
+ def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str] = None) -> None:
175
+ if new_cache_dir is None:
176
+ new_cache_dir = DIFFUSERS_CACHE
177
+ if old_cache_dir is None:
178
+ old_cache_dir = old_diffusers_cache
179
+
180
+ old_cache_dir = Path(old_cache_dir).expanduser()
181
+ new_cache_dir = Path(new_cache_dir).expanduser()
182
+ for old_blob_path in old_cache_dir.glob("**/blobs/*"):
183
+ if old_blob_path.is_file() and not old_blob_path.is_symlink():
184
+ new_blob_path = new_cache_dir / old_blob_path.relative_to(old_cache_dir)
185
+ new_blob_path.parent.mkdir(parents=True, exist_ok=True)
186
+ os.replace(old_blob_path, new_blob_path)
187
+ try:
188
+ os.symlink(new_blob_path, old_blob_path)
189
+ except OSError:
190
+ logger.warning(
191
+ "Could not create symlink between old cache and new cache. If you use an older version of diffusers again, files will be re-downloaded."
192
+ )
193
+ # At this point, old_cache_dir contains symlinks to the new cache (it can still be used).
194
+
195
+
196
+ cache_version_file = os.path.join(DIFFUSERS_CACHE, "version_diffusers_cache.txt")
197
+ if not os.path.isfile(cache_version_file):
198
+ cache_version = 0
199
+ else:
200
+ with open(cache_version_file) as f:
201
+ cache_version = int(f.read())
202
+
203
+ if cache_version < 1:
204
+ old_cache_is_not_empty = os.path.isdir(old_diffusers_cache) and len(os.listdir(old_diffusers_cache)) > 0
205
+ if old_cache_is_not_empty:
206
+ logger.warning(
207
+ "The cache for model files in Diffusers v0.14.0 has moved to a new location. Moving your "
208
+ "existing cached models. This is a one-time operation, you can interrupt it or run it "
209
+ "later by calling `diffusers.utils.hub_utils.move_cache()`."
210
+ )
211
+ try:
212
+ move_cache()
213
+ except Exception as e:
214
+ trace = "\n".join(traceback.format_tb(e.__traceback__))
215
+ logger.error(
216
+ f"There was a problem when trying to move your cache:\n\n{trace}\n{e.__class__.__name__}: {e}\n\nPlease "
217
+ "file an issue at https://github.com/huggingface/diffusers/issues/new/choose, copy paste this whole "
218
+ "message and we will do our best to help."
219
+ )
220
+
221
+ if cache_version < 1:
222
+ try:
223
+ os.makedirs(DIFFUSERS_CACHE, exist_ok=True)
224
+ with open(cache_version_file, "w") as f:
225
+ f.write("1")
226
+ except Exception:
227
+ logger.warning(
228
+ f"There was a problem when trying to write in your cache folder ({DIFFUSERS_CACHE}). Please, ensure "
229
+ "the directory exists and can be written to."
230
+ )
231
+
232
+
233
+ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
234
+ if variant is not None:
235
+ splits = weights_name.split(".")
236
+ splits = splits[:-1] + [variant] + splits[-1:]
237
+ weights_name = ".".join(splits)
238
+
239
+ return weights_name
240
+
241
+
242
+ def _get_model_file(
243
+ pretrained_model_name_or_path,
244
+ *,
245
+ weights_name,
246
+ subfolder,
247
+ cache_dir,
248
+ force_download,
249
+ proxies,
250
+ resume_download,
251
+ local_files_only,
252
+ use_auth_token,
253
+ user_agent,
254
+ revision,
255
+ commit_hash=None,
256
+ ):
257
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
258
+ if os.path.isfile(pretrained_model_name_or_path):
259
+ return pretrained_model_name_or_path
260
+ elif os.path.isdir(pretrained_model_name_or_path):
261
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
262
+ # Load from a PyTorch checkpoint
263
+ model_file = os.path.join(pretrained_model_name_or_path, weights_name)
264
+ return model_file
265
+ elif subfolder is not None and os.path.isfile(
266
+ os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
267
+ ):
268
+ model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
269
+ return model_file
270
+ else:
271
+ raise EnvironmentError(
272
+ f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
273
+ )
274
+ else:
275
+ # 1. First check if deprecated way of loading from branches is used
276
+ if (
277
+ revision in DEPRECATED_REVISION_ARGS
278
+ and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME)
279
+ and version.parse(version.parse(__version__).base_version) >= version.parse("0.17.0")
280
+ ):
281
+ try:
282
+ model_file = hf_hub_download(
283
+ pretrained_model_name_or_path,
284
+ filename=_add_variant(weights_name, revision),
285
+ cache_dir=cache_dir,
286
+ force_download=force_download,
287
+ proxies=proxies,
288
+ resume_download=resume_download,
289
+ local_files_only=local_files_only,
290
+ use_auth_token=use_auth_token,
291
+ user_agent=user_agent,
292
+ subfolder=subfolder,
293
+ revision=revision or commit_hash,
294
+ )
295
+ warnings.warn(
296
+ f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
297
+ FutureWarning,
298
+ )
299
+ return model_file
300
+ except: # noqa: E722
301
+ warnings.warn(
302
+ f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name, revision)} file in the 'main' branch of {pretrained_model_name_or_path}. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name, revision)}' so that the correct variant file can be added.",
303
+ FutureWarning,
304
+ )
305
+ try:
306
+ # 2. Load model file as usual
307
+ model_file = hf_hub_download(
308
+ pretrained_model_name_or_path,
309
+ filename=weights_name,
310
+ cache_dir=cache_dir,
311
+ force_download=force_download,
312
+ proxies=proxies,
313
+ resume_download=resume_download,
314
+ local_files_only=local_files_only,
315
+ use_auth_token=use_auth_token,
316
+ user_agent=user_agent,
317
+ subfolder=subfolder,
318
+ revision=revision or commit_hash,
319
+ )
320
+ return model_file
321
+
322
+ except RepositoryNotFoundError:
323
+ raise EnvironmentError(
324
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
325
+ "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
326
+ "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
327
+ "login`."
328
+ )
329
+ except RevisionNotFoundError:
330
+ raise EnvironmentError(
331
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
332
+ "this model name. Check the model page at "
333
+ f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
334
+ )
335
+ except EntryNotFoundError:
336
+ raise EnvironmentError(
337
+ f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
338
+ )
339
+ except HTTPError as err:
340
+ raise EnvironmentError(
341
+ f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
342
+ )
343
+ except ValueError:
344
+ raise EnvironmentError(
345
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
346
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
347
+ f" directory containing a file named {weights_name} or"
348
+ " \nCheckout your internet connection or see how to run the library in"
349
+ " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
350
+ )
351
+ except EnvironmentError:
352
+ raise EnvironmentError(
353
+ f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
354
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
355
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
356
+ f"containing a file named {weights_name}"
357
+ )
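Two of the helpers above are small pure functions that are easy to sanity-check in isolation; the snippet below (not part of the commit) shows their expected behaviour. Bear in mind that importing this module also runs the one-time cache-version check near the top of the file, and the snapshot path is a made-up example.

```python
from diffusers.utils.hub_utils import _add_variant, extract_commit_hash

# Insert the variant name before the file extension.
print(_add_variant("diffusion_pytorch_model.bin", "fp16"))
# -> diffusion_pytorch_model.fp16.bin

# Pull the commit hash out of a hub cache path (the 40-char hex segment after "snapshots/").
print(extract_commit_hash(
    "/cache/models--org--repo/snapshots/0123456789abcdef0123456789abcdef01234567/config.json"
))
# -> 0123456789abcdef0123456789abcdef01234567
```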
diffusers/utils/import_utils.py ADDED
@@ -0,0 +1,649 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Import utilities: Utilities related to imports and our lazy inits.
16
+ """
17
+
18
+ import importlib.util
19
+ import operator as op
20
+ import os
21
+ import sys
22
+ from collections import OrderedDict
23
+ from typing import Union
24
+
25
+ from packaging import version
26
+ from packaging.version import Version, parse
27
+
28
+ from . import logging
29
+
30
+
31
+ # The package importlib_metadata is in a different place, depending on the python version.
32
+ if sys.version_info < (3, 8):
33
+ import importlib_metadata
34
+ else:
35
+ import importlib.metadata as importlib_metadata
36
+
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+ ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
41
+ ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
42
+
43
+ USE_TF = os.environ.get("USE_TF", "AUTO").upper()
44
+ USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
45
+ USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
46
+ USE_SAFETENSORS = os.environ.get("USE_SAFETENSORS", "AUTO").upper()
47
+
48
+ STR_OPERATION_TO_FUNC = {
49
+ ">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt
50
+ }
51
+
52
+ _torch_version = "N/A"
53
+ if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
54
+ _torch_available = importlib.util.find_spec("torch") is not None
55
+ if _torch_available:
56
+ try:
57
+ _torch_version = importlib_metadata.version("torch")
58
+ logger.info(f"PyTorch version {_torch_version} available.")
59
+ except importlib_metadata.PackageNotFoundError:
60
+ _torch_available = False
61
+ else:
62
+ logger.info("Disabling PyTorch because USE_TORCH is set")
63
+ _torch_available = False
64
+
65
+
66
+ _tf_version = "N/A"
67
+ if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
68
+ _tf_available = importlib.util.find_spec("tensorflow") is not None
69
+ if _tf_available:
70
+ candidates = (
71
+ "tensorflow",
72
+ "tensorflow-cpu",
73
+ "tensorflow-gpu",
74
+ "tf-nightly",
75
+ "tf-nightly-cpu",
76
+ "tf-nightly-gpu",
77
+ "intel-tensorflow",
78
+ "intel-tensorflow-avx512",
79
+ "tensorflow-rocm",
80
+ "tensorflow-macos",
81
+ "tensorflow-aarch64",
82
+ )
83
+ _tf_version = None
84
+ # For the metadata, we have to look for both tensorflow and tensorflow-cpu
85
+ for pkg in candidates:
86
+ try:
87
+ _tf_version = importlib_metadata.version(pkg)
88
+ break
89
+ except importlib_metadata.PackageNotFoundError:
90
+ pass
91
+ _tf_available = _tf_version is not None
92
+ if _tf_available:
93
+ if version.parse(_tf_version) < version.parse("2"):
94
+ logger.info(f"TensorFlow found but with version {_tf_version}. "
95
+ "Diffusers requires version 2 minimum.")
96
+ _tf_available = False
97
+ else:
98
+ logger.info(f"TensorFlow version {_tf_version} available.")
99
+ else:
100
+ logger.info("Disabling Tensorflow because USE_TORCH is set")
101
+ _tf_available = False
102
+
103
+ _jax_version = "N/A"
104
+ _flax_version = "N/A"
105
+ if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
106
+ _flax_available = importlib.util.find_spec("jax") is not None \
107
+ and importlib.util.find_spec("flax") is not None
108
+ if _flax_available:
109
+ try:
110
+ _jax_version = importlib_metadata.version("jax")
111
+ _flax_version = importlib_metadata.version("flax")
112
+ logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
113
+ except importlib_metadata.PackageNotFoundError:
114
+ _flax_available = False
115
+ else:
116
+ _flax_available = False
117
+
118
+ if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES:
119
+ _safetensors_available = importlib.util.find_spec("safetensors") is not None
120
+ if _safetensors_available:
121
+ try:
122
+ _safetensors_version = importlib_metadata.version("safetensors")
123
+ logger.info(f"Safetensors version {_safetensors_version} available.")
124
+ except importlib_metadata.PackageNotFoundError:
125
+ _safetensors_available = False
126
+ else:
127
+ logger.info("Disabling Safetensors because USE_TF is set")
128
+ _safetensors_available = False
129
+
130
+ _transformers_available = importlib.util.find_spec("transformers") is not None
131
+ try:
132
+ _transformers_version = importlib_metadata.version("transformers")
133
+ logger.debug(f"Successfully imported transformers version {_transformers_version}")
134
+ except importlib_metadata.PackageNotFoundError:
135
+ _transformers_available = False
136
+
137
+
138
+ _inflect_available = importlib.util.find_spec("inflect") is not None
139
+ try:
140
+ _inflect_version = importlib_metadata.version("inflect")
141
+ logger.debug(f"Successfully imported inflect version {_inflect_version}")
142
+ except importlib_metadata.PackageNotFoundError:
143
+ _inflect_available = False
144
+
145
+
146
+ _unidecode_available = importlib.util.find_spec("unidecode") is not None
147
+ try:
148
+ _unidecode_version = importlib_metadata.version("unidecode")
149
+ logger.debug(f"Successfully imported unidecode version {_unidecode_version}")
150
+ except importlib_metadata.PackageNotFoundError:
151
+ _unidecode_available = False
152
+
153
+
154
+ _onnxruntime_version = "N/A"
155
+ _onnx_available = importlib.util.find_spec("onnxruntime") is not None
156
+ if _onnx_available:
157
+ candidates = (
158
+ "onnxruntime",
159
+ "onnxruntime-gpu",
160
+ "ort_nightly_gpu",
161
+ "onnxruntime-directml",
162
+ "onnxruntime-openvino",
163
+ "ort_nightly_directml",
164
+ "onnxruntime-rocm",
165
+ "onnxruntime-training",
166
+ )
167
+ _onnxruntime_version = None
168
+ # For the metadata, we have to look for both onnxruntime and onnxruntime-gpu
169
+ for pkg in candidates:
170
+ try:
171
+ _onnxruntime_version = importlib_metadata.version(pkg)
172
+ break
173
+ except importlib_metadata.PackageNotFoundError:
174
+ pass
175
+ _onnx_available = _onnxruntime_version is not None
176
+ if _onnx_available:
177
+ logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}")
178
+
179
+ # (sayakpaul): importlib.util.find_spec("opencv-python") returns None even when it's installed.
180
+ # _opencv_available = importlib.util.find_spec("opencv-python") is not None
181
+ try:
182
+ candidates = (
183
+ "opencv-python",
184
+ "opencv-contrib-python",
185
+ "opencv-python-headless",
186
+ "opencv-contrib-python-headless",
187
+ )
188
+ _opencv_version = None
189
+ for pkg in candidates:
190
+ try:
191
+ _opencv_version = importlib_metadata.version(pkg)
192
+ break
193
+ except importlib_metadata.PackageNotFoundError:
194
+ pass
195
+ _opencv_available = _opencv_version is not None
196
+ if _opencv_available:
197
+ logger.debug(f"Successfully imported cv2 version {_opencv_version}")
198
+ except importlib_metadata.PackageNotFoundError:
199
+ _opencv_available = False
200
+
201
+ _scipy_available = importlib.util.find_spec("scipy") is not None
202
+ try:
203
+ _scipy_version = importlib_metadata.version("scipy")
204
+ logger.debug(f"Successfully imported scipy version {_scipy_version}")
205
+ except importlib_metadata.PackageNotFoundError:
206
+ _scipy_available = False
207
+
208
+ _librosa_available = importlib.util.find_spec("librosa") is not None
209
+ try:
210
+ _librosa_version = importlib_metadata.version("librosa")
211
+ logger.debug(f"Successfully imported librosa version {_librosa_version}")
212
+ except importlib_metadata.PackageNotFoundError:
213
+ _librosa_available = False
214
+
215
+ _accelerate_available = importlib.util.find_spec("accelerate") is not None
216
+ try:
217
+ _accelerate_version = importlib_metadata.version("accelerate")
218
+ logger.debug(f"Successfully imported accelerate version {_accelerate_version}")
219
+ except importlib_metadata.PackageNotFoundError:
220
+ _accelerate_available = False
221
+
222
+ _xformers_available = importlib.util.find_spec("xformers") is not None
223
+ try:
224
+ _xformers_version = importlib_metadata.version("xformers")
225
+ if _torch_available:
226
+ import torch
227
+
228
+ if version.Version(torch.__version__) < version.Version("1.12"):
229
+ raise ValueError("PyTorch should be >= 1.12")
230
+ logger.debug(f"Successfully imported xformers version {_xformers_version}")
231
+ except importlib_metadata.PackageNotFoundError:
232
+ _xformers_available = False
233
+
234
+ _k_diffusion_available = importlib.util.find_spec("k_diffusion") is not None
235
+ try:
236
+ _k_diffusion_version = importlib_metadata.version("k_diffusion")
237
+ logger.debug(f"Successfully imported k-diffusion version {_k_diffusion_version}")
238
+ except importlib_metadata.PackageNotFoundError:
239
+ _k_diffusion_available = False
240
+
241
+ _note_seq_available = importlib.util.find_spec("note_seq") is not None
242
+ try:
243
+ _note_seq_version = importlib_metadata.version("note_seq")
244
+ logger.debug(f"Successfully imported note-seq version {_note_seq_version}")
245
+ except importlib_metadata.PackageNotFoundError:
246
+ _note_seq_available = False
247
+
248
+ _wandb_available = importlib.util.find_spec("wandb") is not None
249
+ try:
250
+ _wandb_version = importlib_metadata.version("wandb")
251
+ logger.debug(f"Successfully imported wandb version {_wandb_version }")
252
+ except importlib_metadata.PackageNotFoundError:
253
+ _wandb_available = False
254
+
255
+ _omegaconf_available = importlib.util.find_spec("omegaconf") is not None
256
+ try:
257
+ _omegaconf_version = importlib_metadata.version("omegaconf")
258
+ logger.debug(f"Successfully imported omegaconf version {_omegaconf_version}")
259
+ except importlib_metadata.PackageNotFoundError:
260
+ _omegaconf_available = False
261
+
262
+ _tensorboard_available = importlib.util.find_spec("tensorboard") is not None
263
+ try:
264
+ _tensorboard_version = importlib_metadata.version("tensorboard")
265
+ logger.debug(f"Successfully imported tensorboard version {_tensorboard_version}")
266
+ except importlib_metadata.PackageNotFoundError:
267
+ _tensorboard_available = False
268
+
269
+
270
+ _compel_available = importlib.util.find_spec("compel") is not None
271
+ try:
272
+ _compel_version = importlib_metadata.version("compel")
273
+ logger.debug(f"Successfully imported compel version {_compel_version}")
274
+ except importlib_metadata.PackageNotFoundError:
275
+ _compel_available = False
276
+
277
+
278
+ _ftfy_available = importlib.util.find_spec("ftfy") is not None
279
+ try:
280
+ _ftfy_version = importlib_metadata.version("ftfy")
281
+ logger.debug(f"Successfully imported ftfy version {_ftfy_version}")
282
+ except importlib_metadata.PackageNotFoundError:
283
+ _ftfy_available = False
284
+
285
+
286
+ _bs4_available = importlib.util.find_spec("bs4") is not None
287
+ try:
288
+ # importlib metadata under different name
289
+ _bs4_version = importlib_metadata.version("beautifulsoup4")
290
+ logger.debug(f"Successfully imported ftfy version {_bs4_version}")
291
+ except importlib_metadata.PackageNotFoundError:
292
+ _bs4_available = False
293
+
294
+ _torchsde_available = importlib.util.find_spec("torchsde") is not None
295
+ try:
296
+ _torchsde_version = importlib_metadata.version("torchsde")
297
+ logger.debug(f"Successfully imported torchsde version {_torchsde_version}")
298
+ except importlib_metadata.PackageNotFoundError:
299
+ _torchsde_available = False
300
+
301
+
302
+ def is_torch_available():
303
+ return _torch_available
304
+
305
+
306
+ def is_safetensors_available():
307
+ return _safetensors_available
308
+
309
+
310
+ def is_tf_available():
311
+ return _tf_available
312
+
313
+
314
+ def is_flax_available():
315
+ return _flax_available
316
+
317
+
318
+ def is_transformers_available():
319
+ return _transformers_available
320
+
321
+
322
+ def is_inflect_available():
323
+ return _inflect_available
324
+
325
+
326
+ def is_unidecode_available():
327
+ return _unidecode_available
328
+
329
+
330
+ def is_onnx_available():
331
+ return _onnx_available
332
+
333
+
334
+ def is_opencv_available():
335
+ return _opencv_available
336
+
337
+
338
+ def is_scipy_available():
339
+ return _scipy_available
340
+
341
+
342
+ def is_librosa_available():
343
+ return _librosa_available
344
+
345
+
346
+ def is_xformers_available():
347
+ return _xformers_available
348
+
349
+
350
+ def is_accelerate_available():
351
+ return _accelerate_available
352
+
353
+
354
+ def is_k_diffusion_available():
355
+ return _k_diffusion_available
356
+
357
+
358
+ def is_note_seq_available():
359
+ return _note_seq_available
360
+
361
+
362
+ def is_wandb_available():
363
+ return _wandb_available
364
+
365
+
366
+ def is_omegaconf_available():
367
+ return _omegaconf_available
368
+
369
+
370
+ def is_tensorboard_available():
371
+ return _tensorboard_available
372
+
373
+
374
+ def is_compel_available():
375
+ return _compel_available
376
+
377
+
378
+ def is_ftfy_available():
379
+ return _ftfy_available
380
+
381
+
382
+ def is_bs4_available():
383
+ return _bs4_available
384
+
385
+
386
+ def is_torchsde_available():
387
+ return _torchsde_available
388
+
389
+
390
+ # docstyle-ignore
391
+ FLAX_IMPORT_ERROR = """
392
+ {0} requires the FLAX library but it was not found in your environment.
393
+ Check out the instructions on the installation page: https://github.com/google/flax
394
+ and follow the ones that match your environment.
395
+ """
396
+
397
+ # docstyle-ignore
398
+ INFLECT_IMPORT_ERROR = """
399
+ {0} requires the inflect library but it was not found in your environment.
400
+ You can install it with pip: `pip install inflect`
401
+ """
402
+
403
+ # docstyle-ignore
404
+ PYTORCH_IMPORT_ERROR = """
405
+ {0} requires the PyTorch library but it was not found in your environment.
406
+ Check out the instructions on the installation page: https://pytorch.org/get-started/locally/
407
+ and follow the ones that match your environment.
408
+ """
409
+
410
+ # docstyle-ignore
411
+ ONNX_IMPORT_ERROR = """
412
+ {0} requires the onnxruntime library but it was not found in your environment.
413
+ You can install it with pip: `pip install onnxruntime`
414
+ """
415
+
416
+ # docstyle-ignore
417
+ OPENCV_IMPORT_ERROR = """
418
+ {0} requires the OpenCV library but it was not found in your environment.
419
+ You can install it with pip: `pip install opencv-python`
420
+ """
421
+
422
+ # docstyle-ignore
423
+ SCIPY_IMPORT_ERROR = """
424
+ {0} requires the scipy library but it was not found in your environment.
425
+ You can install it with pip: `pip install scipy`
426
+ """
427
+
428
+ # docstyle-ignore
429
+ LIBROSA_IMPORT_ERROR = """
430
+ {0} requires the librosa library but it was not found in your environment.
431
+ Check out the instructions on the installation page: https://librosa.org/doc/latest/install.html
432
+ and follow the ones that match your environment.
433
+ """
434
+
435
+ # docstyle-ignore
436
+ TRANSFORMERS_IMPORT_ERROR = """
437
+ {0} requires the transformers library but it was not found in your environment.
438
+ You can install it with pip: `pip install transformers`
439
+ """
440
+
441
+ # docstyle-ignore
442
+ UNIDECODE_IMPORT_ERROR = """
443
+ {0} requires the unidecode library but it was not found in your environment.
444
+ You can install it with pip: `pip install Unidecode`
445
+ """
446
+
447
+ # docstyle-ignore
448
+ K_DIFFUSION_IMPORT_ERROR = """
449
+ {0} requires the k-diffusion library but it was not found in your environment.
450
+ You can install it with pip: `pip install k-diffusion`
451
+ """
452
+
453
+ # docstyle-ignore
454
+ NOTE_SEQ_IMPORT_ERROR = """
455
+ {0} requires the note-seq library but it was not found in your environment.
456
+ You can install it with pip: `pip install note-seq`
457
+ """
458
+
459
+ # docstyle-ignore
460
+ WANDB_IMPORT_ERROR = """
461
+ {0} requires the wandb library but it was not found in your environment.
462
+ You can install it with pip: `pip install wandb`
463
+ """
464
+
465
+ # docstyle-ignore
466
+ OMEGACONF_IMPORT_ERROR = """
467
+ {0} requires the omegaconf library but it was not found in your environment.
468
+ You can install it with pip: `pip install omegaconf`
469
+ """
470
+
471
+ # docstyle-ignore
472
+ TENSORBOARD_IMPORT_ERROR = """
473
+ {0} requires the tensorboard library but it was not found in your environment.
474
+ You can install it with pip: `pip install tensorboard`
475
+ """
476
+
477
+
478
+ # docstyle-ignore
479
+ COMPEL_IMPORT_ERROR = """
480
+ {0} requires the compel library but it was not found in your environment.
481
+ You can install it with pip: `pip install compel`
482
+ """
483
+
484
+ # docstyle-ignore
485
+ BS4_IMPORT_ERROR = """
486
+ {0} requires the Beautiful Soup library but it was not found in your environment. You can install it with pip:
487
+ `pip install beautifulsoup4`. Please note that you may need to restart your runtime after installation.
488
+ """
489
+
490
+ # docstyle-ignore
491
+ FTFY_IMPORT_ERROR = """
492
+ {0} requires the ftfy library but it was not found in your environment. Check out the instructions on the
493
+ installation section: https://github.com/rspeer/python-ftfy/tree/master#installing and follow the ones
494
+ that match your environment. Please note that you may need to restart your runtime after installation.
495
+ """
496
+
497
+ # docstyle-ignore
498
+ TORCHSDE_IMPORT_ERROR = """
499
+ {0} requires the torchsde library but it was not found in your environment.
500
+ You can install it with pip: `pip install torchsde`
501
+ """
502
+
503
+
504
+ BACKENDS_MAPPING = OrderedDict(
505
+ [
506
+ ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
507
+ ("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
508
+ ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)),
509
+ ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)),
510
+ ("opencv", (is_opencv_available, OPENCV_IMPORT_ERROR)),
511
+ ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
512
+ ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
513
+ ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
514
+ ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
515
+ ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)),
516
+ ("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)),
517
+ ("note_seq", (is_note_seq_available, NOTE_SEQ_IMPORT_ERROR)),
518
+ ("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)),
519
+ ("omegaconf", (is_omegaconf_available, OMEGACONF_IMPORT_ERROR)),
520
+ ("tensorboard", (_tensorboard_available, TENSORBOARD_IMPORT_ERROR)),
521
+ ("compel", (_compel_available, COMPEL_IMPORT_ERROR)),
522
+ ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)),
523
+ ("torchsde", (_torchsde_available, TORCHSDE_IMPORT_ERROR)),
524
+ ]
525
+ )
526
+
527
+
528
+ def requires_backends(obj, backends):
529
+ if not isinstance(backends, (list, tuple)):
530
+ backends = [backends]
531
+
532
+ name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
533
+ checks = (BACKENDS_MAPPING[backend] for backend in backends)
534
+ failed = [msg.format(name) for available, msg in checks if not available()]
535
+ if failed:
536
+ raise ImportError("".join(failed))
537
+
538
+ if name in [
539
+ "VersatileDiffusionTextToImagePipeline",
540
+ "VersatileDiffusionPipeline",
541
+ "VersatileDiffusionDualGuidedPipeline",
542
+ "StableDiffusionImageVariationPipeline",
543
+ "UnCLIPPipeline",
544
+ ] and is_transformers_version("<", "4.25.0"):
545
+ raise ImportError(
546
+ f"You need to install `transformers>=4.25` in order to use {name}: \n```\n pip install"
547
+ " --upgrade transformers \n```"
548
+ )
549
+
550
+ if name in ["StableDiffusionDepth2ImgPipeline", "StableDiffusionPix2PixZeroPipeline"] and is_transformers_version(
551
+ "<", "4.26.0"
552
+ ):
553
+ raise ImportError(
554
+ f"You need to install `transformers>=4.26` in order to use {name}: \n```\n pip install"
555
+ " --upgrade transformers \n```"
556
+ )
557
+
558
+
559
+ class DummyObject(type):
560
+ """
561
+ Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
562
+ `requires_backend` each time a user tries to access any method of that class.
563
+ """
564
+
565
+ def __getattr__(cls, key):
566
+ if key.startswith("_"):
567
+ return super().__getattr__(cls, key)
568
+ requires_backends(cls, cls._backends)
569
+
570
+
571
+ # This function was copied from:
572
+ # https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
573
+ def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
574
+ """
575
+ Compares a library version to some requirement using a given operation.
576
+ Args:
577
+ library_or_version (`str` or `packaging.version.Version`):
578
+ A library name or a version to check.
579
+ operation (`str`):
580
+ A string representation of an operator, such as `">"` or `"<="`.
581
+ requirement_version (`str`):
582
+ The version to compare the library version against
583
+ """
584
+ if operation not in STR_OPERATION_TO_FUNC.keys():
585
+ raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}")
586
+ operation = STR_OPERATION_TO_FUNC[operation]
587
+ if isinstance(library_or_version, str):
588
+ library_or_version = parse(importlib_metadata.version(library_or_version))
589
+ return operation(library_or_version, parse(requirement_version))
590
+
591
+
592
+ # This function was copied from:
593
+ # https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
594
+ def is_torch_version(operation: str, version: str):
595
+ """
596
+ Compares the current PyTorch version to a given reference with an operation.
597
+ Args:
598
+ operation (`str`):
599
+ A string representation of an operator, such as `">"` or `"<="`
600
+ version (`str`):
601
+ A string version of PyTorch
602
+ """
603
+ return compare_versions(parse(_torch_version), operation, version)
604
+
605
+
606
+ def is_transformers_version(operation: str, version: str):
607
+ """
608
+ Compares the current Transformers version to a given reference with an operation.
609
+ Args:
610
+ operation (`str`):
611
+ A string representation of an operator, such as `">"` or `"<="`
612
+ version (`str`):
613
+ A version string
614
+ """
615
+ if not _transformers_available:
616
+ return False
617
+ return compare_versions(parse(_transformers_version), operation, version)
618
+
619
+
620
+ def is_accelerate_version(operation: str, version: str):
621
+ """
622
+ Compares the current Accelerate version to a given reference with an operation.
623
+ Args:
624
+ operation (`str`):
625
+ A string representation of an operator, such as `">"` or `"<="`
626
+ version (`str`):
627
+ A version string
628
+ """
629
+ if not _accelerate_available:
630
+ return False
631
+ return compare_versions(parse(_accelerate_version), operation, version)
632
+
633
+
634
+ def is_k_diffusion_version(operation: str, version: str):
635
+ """
636
+ Compares the current k-diffusion version to a given reference with an operation.
637
+ Args:
638
+ operation (`str`):
639
+ A string representation of an operator, such as `">"` or `"<="`
640
+ version (`str`):
641
+ A version string
642
+ """
643
+ if not _k_diffusion_available:
644
+ return False
645
+ return compare_versions(parse(_k_diffusion_version), operation, version)
646
+
647
+
648
+ class OptionalDependencyNotAvailable(BaseException):
649
+ """An error indicating that an optional dependency of Diffusers was not found in the environment."""
diffusers/utils/logging.py ADDED
@@ -0,0 +1,342 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 Optuna, Hugging Face
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Logging utilities."""
16
+
17
+ import logging
18
+ import os
19
+ import sys
20
+ import threading
21
+ from logging import (
22
+ CRITICAL, # NOQA
23
+ DEBUG, # NOQA
24
+ ERROR, # NOQA
25
+ FATAL, # NOQA
26
+ INFO, # NOQA
27
+ NOTSET, # NOQA
28
+ WARN, # NOQA
29
+ WARNING, # NOQA
30
+ )
31
+ from typing import Optional
32
+
33
+ from tqdm import auto as tqdm_lib
34
+
35
+
36
+ _lock = threading.Lock()
37
+ _default_handler: Optional[logging.Handler] = None
38
+
39
+ log_levels = {
40
+ "debug": logging.DEBUG,
41
+ "info": logging.INFO,
42
+ "warning": logging.WARNING,
43
+ "error": logging.ERROR,
44
+ "critical": logging.CRITICAL,
45
+ }
46
+
47
+ _default_log_level = logging.WARNING
48
+
49
+ _tqdm_active = True
50
+
51
+
52
+ def _get_default_logging_level():
53
+ """
54
+ If DIFFUSERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
55
+ not - fall back to `_default_log_level`
56
+ """
57
+ env_level_str = os.getenv("DIFFUSERS_VERBOSITY", None)
58
+ if env_level_str:
59
+ if env_level_str in log_levels:
60
+ return log_levels[env_level_str]
61
+ else:
62
+ logging.getLogger().warning(
63
+ f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, "
64
+ f"has to be one of: { ', '.join(log_levels.keys()) }"
65
+ )
66
+ return _default_log_level
67
+
68
+
69
+ def _get_library_name() -> str:
70
+ return __name__.split(".")[0]
71
+
72
+
73
+ def _get_library_root_logger() -> logging.Logger:
74
+ return logging.getLogger(_get_library_name())
75
+
76
+
77
+ def _configure_library_root_logger() -> None:
78
+ global _default_handler
79
+
80
+ with _lock:
81
+ if _default_handler:
82
+ # This library has already configured the library root logger.
83
+ return
84
+ _default_handler = logging.StreamHandler() # Set sys.stderr as stream.
85
+ _default_handler.flush = sys.stderr.flush
86
+
87
+ # Apply our default configuration to the library root logger.
88
+ library_root_logger = _get_library_root_logger()
89
+ library_root_logger.addHandler(_default_handler)
90
+ library_root_logger.setLevel(_get_default_logging_level())
91
+ library_root_logger.propagate = False
92
+
93
+
94
+ def _reset_library_root_logger() -> None:
95
+ global _default_handler
96
+
97
+ with _lock:
98
+ if not _default_handler:
99
+ return
100
+
101
+ library_root_logger = _get_library_root_logger()
102
+ library_root_logger.removeHandler(_default_handler)
103
+ library_root_logger.setLevel(logging.NOTSET)
104
+ _default_handler = None
105
+
106
+
107
+ def get_log_levels_dict():
108
+ return log_levels
109
+
110
+
111
+ def get_logger(name: Optional[str] = None) -> logging.Logger:
112
+ """
113
+ Return a logger with the specified name.
114
+
115
+ This function is not supposed to be directly accessed unless you are writing a custom diffusers module.
116
+ """
117
+
118
+ if name is None:
119
+ name = _get_library_name()
120
+
121
+ _configure_library_root_logger()
122
+ return logging.getLogger(name)
123
+
124
+
125
+ def get_verbosity() -> int:
126
+ """
127
+ Return the current level for the 🤗 Diffusers' root logger as an int.
128
+
129
+ Returns:
130
+ `int`: The logging level.
131
+
132
+ <Tip>
133
+
134
+ 🤗 Diffusers has following logging levels:
135
+
136
+ - 50: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
137
+ - 40: `diffusers.logging.ERROR`
138
+ - 30: `diffusers.logging.WARNING` or `diffusers.logging.WARN`
139
+ - 20: `diffusers.logging.INFO`
140
+ - 10: `diffusers.logging.DEBUG`
141
+
142
+ </Tip>"""
143
+
144
+ _configure_library_root_logger()
145
+ return _get_library_root_logger().getEffectiveLevel()
146
+
147
+
148
+ def set_verbosity(verbosity: int) -> None:
149
+ """
150
+ Set the verbosity level for the 🤗 Diffusers' root logger.
151
+
152
+ Args:
153
+ verbosity (`int`):
154
+ Logging level, e.g., one of:
155
+
156
+ - `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
157
+ - `diffusers.logging.ERROR`
158
+ - `diffusers.logging.WARNING` or `diffusers.logging.WARN`
159
+ - `diffusers.logging.INFO`
160
+ - `diffusers.logging.DEBUG`
161
+ """
162
+
163
+ _configure_library_root_logger()
164
+ _get_library_root_logger().setLevel(verbosity)
165
+
166
+
167
+ def set_verbosity_info():
168
+ """Set the verbosity to the `INFO` level."""
169
+ return set_verbosity(INFO)
170
+
171
+
172
+ def set_verbosity_warning():
173
+ """Set the verbosity to the `WARNING` level."""
174
+ return set_verbosity(WARNING)
175
+
176
+
177
+ def set_verbosity_debug():
178
+ """Set the verbosity to the `DEBUG` level."""
179
+ return set_verbosity(DEBUG)
180
+
181
+
182
+ def set_verbosity_error():
183
+ """Set the verbosity to the `ERROR` level."""
184
+ return set_verbosity(ERROR)
185
+
186
+
187
+ def disable_default_handler() -> None:
188
+ """Disable the default handler of the HuggingFace Diffusers' root logger."""
189
+
190
+ _configure_library_root_logger()
191
+
192
+ assert _default_handler is not None
193
+ _get_library_root_logger().removeHandler(_default_handler)
194
+
195
+
196
+ def enable_default_handler() -> None:
197
+ """Enable the default handler of the HuggingFace Diffusers' root logger."""
198
+
199
+ _configure_library_root_logger()
200
+
201
+ assert _default_handler is not None
202
+ _get_library_root_logger().addHandler(_default_handler)
203
+
204
+
205
+ def add_handler(handler: logging.Handler) -> None:
206
+ """adds a handler to the HuggingFace Diffusers' root logger."""
207
+
208
+ _configure_library_root_logger()
209
+
210
+ assert handler is not None
211
+ _get_library_root_logger().addHandler(handler)
212
+
213
+
214
+ def remove_handler(handler: logging.Handler) -> None:
215
+ """removes given handler from the HuggingFace Diffusers' root logger."""
216
+
217
+ _configure_library_root_logger()
218
+
219
+ assert handler is not None and handler not in _get_library_root_logger().handlers
220
+ _get_library_root_logger().removeHandler(handler)
221
+
222
+
223
+ def disable_propagation() -> None:
224
+ """
225
+ Disable propagation of the library log outputs. Note that log propagation is disabled by default.
226
+ """
227
+
228
+ _configure_library_root_logger()
229
+ _get_library_root_logger().propagate = False
230
+
231
+
232
+ def enable_propagation() -> None:
233
+ """
234
+ Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent
235
+ double logging if the root logger has been configured.
236
+ """
237
+
238
+ _configure_library_root_logger()
239
+ _get_library_root_logger().propagate = True
240
+
241
+
242
+ def enable_explicit_format() -> None:
243
+ """
244
+ Enable explicit formatting for every HuggingFace Diffusers' logger. The explicit formatter is as follows:
245
+ ```
246
+ [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
247
+ ```
248
+ All handlers currently bound to the root logger are affected by this method.
249
+ """
250
+ handlers = _get_library_root_logger().handlers
251
+
252
+ for handler in handlers:
253
+ formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
254
+ handler.setFormatter(formatter)
255
+
256
+
257
+ def reset_format() -> None:
258
+ """
259
+ Resets the formatting for HuggingFace Diffusers' loggers.
260
+
261
+ All handlers currently bound to the root logger are affected by this method.
262
+ """
263
+ handlers = _get_library_root_logger().handlers
264
+
265
+ for handler in handlers:
266
+ handler.setFormatter(None)
267
+
268
+
269
+ def warning_advice(self, *args, **kwargs):
270
+ """
271
+ This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this
272
+ warning will not be printed
273
+ """
274
+ no_advisory_warnings = os.getenv("DIFFUSERS_NO_ADVISORY_WARNINGS", False)
275
+ if no_advisory_warnings:
276
+ return
277
+ self.warning(*args, **kwargs)
278
+
279
+
280
+ logging.Logger.warning_advice = warning_advice
281
+
282
+
283
+ class EmptyTqdm:
284
+ """Dummy tqdm which doesn't do anything."""
285
+
286
+ def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
287
+ self._iterator = args[0] if args else None
288
+
289
+ def __iter__(self):
290
+ return iter(self._iterator)
291
+
292
+ def __getattr__(self, _):
293
+ """Return empty function."""
294
+
295
+ def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
296
+ return
297
+
298
+ return empty_fn
299
+
300
+ def __enter__(self):
301
+ return self
302
+
303
+ def __exit__(self, type_, value, traceback):
304
+ return
305
+
306
+
307
+ class _tqdm_cls:
308
+ def __call__(self, *args, **kwargs):
309
+ if _tqdm_active:
310
+ return tqdm_lib.tqdm(*args, **kwargs)
311
+ else:
312
+ return EmptyTqdm(*args, **kwargs)
313
+
314
+ def set_lock(self, *args, **kwargs):
315
+ self._lock = None
316
+ if _tqdm_active:
317
+ return tqdm_lib.tqdm.set_lock(*args, **kwargs)
318
+
319
+ def get_lock(self):
320
+ if _tqdm_active:
321
+ return tqdm_lib.tqdm.get_lock()
322
+
323
+
324
+ tqdm = _tqdm_cls()
325
+
326
+
327
+ def is_progress_bar_enabled() -> bool:
328
+ """Return a boolean indicating whether tqdm progress bars are enabled."""
329
+ global _tqdm_active
330
+ return bool(_tqdm_active)
331
+
332
+
333
+ def enable_progress_bar():
334
+ """Enable tqdm progress bar."""
335
+ global _tqdm_active
336
+ _tqdm_active = True
337
+
338
+
339
+ def disable_progress_bar():
340
+ """Disable tqdm progress bar."""
341
+ global _tqdm_active
342
+ _tqdm_active = False
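A short sketch of the logging helpers defined above, assuming the vendored `diffusers.utils` package is importable; the messages are arbitrary.

```python
from diffusers.utils import logging

logger = logging.get_logger("diffusers")   # the library root logger

logging.set_verbosity_info()               # equivalent to DIFFUSERS_VERBOSITY=info
logging.enable_explicit_format()           # [LEVELNAME|FILENAME:LINE] TIME >> MESSAGE

logger.info("Now visible at INFO verbosity.")
logger.warning_advice("Hidden when DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set.")

logging.disable_progress_bar()             # logging.tqdm becomes a no-op wrapper
```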
diffusers/utils/outputs.py ADDED
@@ -0,0 +1,108 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Generic utilities
16
+ """
17
+
18
+ from collections import OrderedDict
19
+ from dataclasses import fields
20
+ from typing import Any, Tuple
21
+
22
+ import numpy as np
23
+
24
+ from .import_utils import is_torch_available
25
+
26
+
27
+ def is_tensor(x):
28
+ """
29
+ Tests if `x` is a `torch.Tensor` or `np.ndarray`.
30
+ """
31
+ if is_torch_available():
32
+ import torch
33
+
34
+ if isinstance(x, torch.Tensor):
35
+ return True
36
+
37
+ return isinstance(x, np.ndarray)
38
+
39
+
40
+ class BaseOutput(OrderedDict):
41
+ """
42
+ Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
43
+ tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
44
+ python dictionary.
45
+
46
+ <Tip warning={true}>
47
+
48
+ You can't unpack a `BaseOutput` directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple
49
+ before.
50
+
51
+ </Tip>
52
+ """
53
+
54
+ def __post_init__(self):
55
+ class_fields = fields(self)
56
+
57
+ # Safety and consistency checks
58
+ if not len(class_fields):
59
+ raise ValueError(f"{self.__class__.__name__} has no fields.")
60
+
61
+ first_field = getattr(self, class_fields[0].name)
62
+ other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
63
+
64
+ if other_fields_are_none and isinstance(first_field, dict):
65
+ for key, value in first_field.items():
66
+ self[key] = value
67
+ else:
68
+ for field in class_fields:
69
+ v = getattr(self, field.name)
70
+ if v is not None:
71
+ self[field.name] = v
72
+
73
+ def __delitem__(self, *args, **kwargs):
74
+ raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
75
+
76
+ def setdefault(self, *args, **kwargs):
77
+ raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
78
+
79
+ def pop(self, *args, **kwargs):
80
+ raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
81
+
82
+ def update(self, *args, **kwargs):
83
+ raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
84
+
85
+ def __getitem__(self, k):
86
+ if isinstance(k, str):
87
+ inner_dict = dict(self.items())
88
+ return inner_dict[k]
89
+ else:
90
+ return self.to_tuple()[k]
91
+
92
+ def __setattr__(self, name, value):
93
+ if name in self.keys() and value is not None:
94
+ # Don't call self.__setitem__ to avoid recursion errors
95
+ super().__setitem__(name, value)
96
+ super().__setattr__(name, value)
97
+
98
+ def __setitem__(self, key, value):
99
+ # Will raise a KeyException if needed
100
+ super().__setitem__(key, value)
101
+ # Don't call self.__setattr__ to avoid recursion errors
102
+ super().__setattr__(key, value)
103
+
104
+ def to_tuple(self) -> Tuple[Any]:
105
+ """
106
+ Convert self to a tuple containing all the attributes/keys that are not `None`.
107
+ """
108
+ return tuple(self[k] for k in self.keys())
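The behaviour described in the `BaseOutput` docstring can be seen in a small sketch; `AudioPipelineOutput` is an illustrative dataclass, not something defined in this commit.

```python
from dataclasses import dataclass
from typing import Optional

import numpy as np

from diffusers.utils.outputs import BaseOutput


@dataclass
class AudioPipelineOutput(BaseOutput):
    audios: np.ndarray
    sample_rate: Optional[int] = None


out = AudioPipelineOutput(audios=np.zeros((1, 16000)), sample_rate=16000)

assert out["audios"] is out.audios   # dict-style and attribute access agree
assert out[0] is out.audios          # integer indexing follows to_tuple() order
audios, sr = out.to_tuple()          # explicit unpacking; None-valued fields are dropped
```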
diffusers/utils/scheduling_utils.py ADDED
@@ -0,0 +1,176 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import importlib
15
+ import os
16
+ from dataclasses import dataclass
17
+ from enum import Enum
18
+ from typing import Any, Dict, Optional, Union
19
+
20
+ import torch
21
+
22
+ from .outputs import BaseOutput
23
+
24
+
25
+ SCHEDULER_CONFIG_NAME = "scheduler_config.json"
26
+
27
+
28
+ # NOTE: We make this type an enum because it simplifies usage in docs and prevents
29
+ # circular imports when used for `_compatibles` within the schedulers module.
30
+ # When it's used as a type in pipelines, it really is a Union because the actual
31
+ # scheduler instance is passed in.
32
+ class KarrasDiffusionSchedulers(Enum):
33
+ DDIMScheduler = 1
34
+ DDPMScheduler = 2
35
+ PNDMScheduler = 3
36
+ LMSDiscreteScheduler = 4
37
+ EulerDiscreteScheduler = 5
38
+ HeunDiscreteScheduler = 6
39
+ EulerAncestralDiscreteScheduler = 7
40
+ DPMSolverMultistepScheduler = 8
41
+ DPMSolverSinglestepScheduler = 9
42
+ KDPM2DiscreteScheduler = 10
43
+ KDPM2AncestralDiscreteScheduler = 11
44
+ DEISMultistepScheduler = 12
45
+ UniPCMultistepScheduler = 13
46
+
47
+
48
+ @dataclass
49
+ class SchedulerOutput(BaseOutput):
50
+ """
51
+ Base class for the scheduler's step function output.
52
+
53
+ Args:
54
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
55
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
56
+ denoising loop.
57
+ """
58
+
59
+ prev_sample: torch.FloatTensor
60
+
61
+
62
+ class SchedulerMixin:
63
+ """
64
+ Mixin containing common functions for the schedulers.
65
+
66
+ Class attributes:
67
+ - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
68
+ `from_config` can be used from a class different than the one used to save the config (should be overridden
69
+ by parent class).
70
+ """
71
+
72
+ config_name = SCHEDULER_CONFIG_NAME
73
+ _compatibles = []
74
+ has_compatibles = True
75
+
76
+ @classmethod
77
+ def from_pretrained(
78
+ cls,
79
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
80
+ subfolder: Optional[str] = None,
81
+ return_unused_kwargs=False,
82
+ **kwargs,
83
+ ):
84
+ r"""
85
+ Instantiate a Scheduler class from a pre-defined JSON configuration file inside a directory or Hub repo.
86
+
87
+ Parameters:
88
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
89
+ Can be either:
90
+
91
+ - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
92
+ organization name, like `google/ddpm-celebahq-256`.
93
+ - A path to a *directory* containing the scheduler configurations saved using
94
+ [`~SchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`.
95
+ subfolder (`str`, *optional*):
96
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
97
+ huggingface.co or downloaded locally), you can specify the folder name here.
98
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
99
+ Whether kwargs that are not consumed by the Python class should be returned or not.
100
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
101
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
102
+ standard cache should not be used.
103
+ force_download (`bool`, *optional*, defaults to `False`):
104
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
105
+ cached versions if they exist.
106
+ resume_download (`bool`, *optional*, defaults to `False`):
107
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
108
+ file exists.
109
+ proxies (`Dict[str, str]`, *optional*):
110
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
111
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
112
+ output_loading_info(`bool`, *optional*, defaults to `False`):
113
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
114
+ local_files_only(`bool`, *optional*, defaults to `False`):
115
+ Whether or not to only look at local files (i.e., do not try to download the model).
116
+ use_auth_token (`str` or *bool*, *optional*):
117
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
118
+ when running `transformers-cli login` (stored in `~/.huggingface`).
119
+ revision (`str`, *optional*, defaults to `"main"`):
120
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
121
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
122
+ identifier allowed by git.
123
+
124
+ <Tip>
125
+
126
+ It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
127
+ models](https://huggingface.co/docs/hub/models-gated#gated-models).
128
+
129
+ </Tip>
130
+
131
+ <Tip>
132
+
133
+ Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
134
+ use this method in a firewalled environment.
135
+
136
+ </Tip>
137
+
138
+ """
139
+ config, kwargs, commit_hash = cls.load_config(
140
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
141
+ subfolder=subfolder,
142
+ return_unused_kwargs=True,
143
+ return_commit_hash=True,
144
+ **kwargs,
145
+ )
146
+ return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
147
+
148
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
149
+ """
150
+ Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the
151
+ [`~SchedulerMixin.from_pretrained`] class method.
152
+
153
+ Args:
154
+ save_directory (`str` or `os.PathLike`):
155
+ Directory where the configuration JSON file will be saved (will be created if it does not exist).
156
+ """
157
+ self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
158
+
159
+ @property
160
+ def compatibles(self):
161
+ """
162
+ Returns all schedulers that are compatible with this scheduler
163
+
164
+ Returns:
165
+ `List[SchedulerMixin]`: List of compatible schedulers
166
+ """
167
+ return self._get_compatibles()
168
+
169
+ @classmethod
170
+ def _get_compatibles(cls):
171
+ compatible_classes_str = list(set([cls.__name__] + cls._compatibles))
172
+ diffusers_library = importlib.import_module(__name__.split(".")[0])
173
+ compatible_classes = [
174
+ getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
175
+ ]
176
+ return compatible_classes
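A sketch of the `SchedulerMixin` save/load round trip, assuming the vendored `HeunDiscreteScheduler` in `diffusers/scheduling_heun_discrete.py` keeps the upstream constructor defaults and is importable from that module path.

```python
from diffusers.scheduling_heun_discrete import HeunDiscreteScheduler

scheduler = HeunDiscreteScheduler()                    # defaults registered via the config system
scheduler.save_pretrained("./my_scheduler")            # writes scheduler_config.json
reloaded = HeunDiscreteScheduler.from_pretrained("./my_scheduler")

# Schedulers listed in _compatibles share this config and can be swapped in via from_config.
print([cls.__name__ for cls in reloaded.compatibles])
```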
diffusers/utils/torch_utils.py ADDED
@@ -0,0 +1,83 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ PyTorch utilities: Utilities related to PyTorch
16
+ """
17
+ from typing import List, Optional, Tuple, Union
18
+
19
+ from . import logging
20
+ from .import_utils import is_torch_available, is_torch_version
21
+
22
+ if is_torch_available():
23
+ import torch
24
+
25
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
26
+
27
+ try:
28
+ from torch._dynamo import allow_in_graph as maybe_allow_in_graph
29
+ except (ImportError, ModuleNotFoundError):
30
+
31
+ def maybe_allow_in_graph(cls):
32
+ return cls
33
+
34
+
35
+ def randn_tensor(
36
+ shape: Union[Tuple, List],
37
+ generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
38
+ device: Optional["torch.device"] = None,
39
+ dtype: Optional["torch.dtype"] = None,
40
+ layout: Optional["torch.layout"] = None,
41
+ ):
42
+ """This is a helper function that allows to create random tensors on the desired `device` with the desired `dtype`. When
43
+ passing a list of generators one can seed each batched size individually. If CPU generators are passed the tensor
44
+ will always be created on CPU.
45
+ """
46
+ # the device used for sampling defaults to the requested device
47
+ rand_device = device
48
+ batch_size = shape[0]
49
+
50
+ layout = layout or torch.strided
51
+ device = device or torch.device("cpu")
52
+
53
+ if generator is not None:
54
+ gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
55
+ if gen_device_type != device.type and gen_device_type == "cpu":
56
+ rand_device = "cpu"
57
+ if device != "mps":
58
+ logger.info(
59
+ f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
60
+ f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
61
+ f" slighly speed up this function by passing a generator that was created on the {device} device."
62
+ )
63
+ elif gen_device_type != device.type and gen_device_type == "cuda":
64
+ raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
65
+
66
+ if isinstance(generator, list):
67
+ shape = (1,) + shape[1:]
68
+ latents = [
69
+ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
70
+ for i in range(batch_size)
71
+ ]
72
+ latents = torch.cat(latents, dim=0).to(device)
73
+ else:
74
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
75
+
76
+ return latents
77
+
78
+
79
+ def is_compiled_module(module):
80
+ """Check whether the module was compiled with torch.compile()"""
81
+ if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"):
82
+ return False
83
+ return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
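A quick sketch of `randn_tensor` with one generator per batch element, assuming the vendored `diffusers.utils.torch_utils` module is importable. The 8-channel, 32×2 latent shape matches `tango_diffusion_light.json` below; everything else is arbitrary.

```python
import torch

from diffusers.utils.torch_utils import randn_tensor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One CPU generator per batch element: reproducible, independently seeded samples.
generators = [torch.Generator(device="cpu").manual_seed(s) for s in (0, 1, 2, 3)]

latents = randn_tensor(
    (4, 8, 32, 2),          # (batch, in_channels, *sample_size)
    generator=generators,
    device=device,
    dtype=torch.float32,
)
print(latents.shape)        # torch.Size([4, 8, 32, 2])
```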
run_gradio.py ADDED
@@ -0,0 +1,87 @@
1
+ import torch
2
+ import gradio as gr
3
+ import soundfile as sf
4
+ import numpy as np
5
+ import random, os
6
+
7
+ from consistencytta import ConsistencyTTA
8
+
9
+
10
+ def seed_all(seed):
11
+ """ Seed all random number generators. """
12
+ seed = int(seed)
13
+ random.seed(seed)
14
+ np.random.seed(seed)
15
+ torch.manual_seed(seed)
16
+ torch.cuda.manual_seed(seed)
17
+ torch.cuda.manual_seed_all(seed)
18
+ torch.cuda.random.manual_seed(seed)
19
+ os.environ['PYTHONHASHSEED'] = str(seed)
20
+ torch.backends.cudnn.benchmark = False
21
+ torch.backends.cudnn.deterministic = True
22
+
23
+
24
+ device = torch.device(
25
+ "cuda:0" if torch.cuda.is_available() else
26
+ "mps" if torch.backends.mps.is_available() else "cpu"
27
+ )
28
+ sr = 16000
29
+
30
+ # Build ConsistencyTTA model
31
+ consistencytta = ConsistencyTTA().to(device)
32
+ consistencytta.eval()
33
+ consistencytta.requires_grad_(False)
34
+
35
+
36
+ def generate(prompt: str, seed: str = '', cfg_weight: float = 4.):
37
+ """ Generate audio from a given prompt.
38
+ Args:
39
+ prompt (str): Text prompt to generate audio from.
40
+ seed (str, optional): Random seed. Defaults to '', which means no seed.
41
+ """
42
+ if seed != '':
43
+ try:
44
+ seed_all(int(seed))
45
+ except:
46
+ pass
47
+
48
+ with torch.no_grad():
49
+ with torch.autocast(
50
+ device_type="cuda", dtype=torch.bfloat16, enabled=torch.cuda.is_available()
51
+ ):
52
+ wav = consistencytta(
53
+ [prompt], num_steps=1, cfg_scale_input=cfg_weight, cfg_scale_post=1., sr=sr
54
+ )
55
+ sf.write("output.wav", wav.T, samplerate=sr, subtype='PCM_16')
56
+
57
+ return "output.wav"
58
+
59
+
60
+ # Generate test audio
61
+ print("Generating test audio...")
62
+ generate("A dog barks as a train passes by.", seed=1)
63
+ print("Test audio generated successfully! Starting Gradio interface...")
64
+
65
+ # Launch Gradio interface
66
+ iface = gr.Interface(
67
+ fn=generate,
68
+ inputs=[
69
+ gr.Textbox(
70
+ label="Text", value="Several people cheer and scream and speak as water flows hard."
71
+ ),
72
+ gr.Textbox(label="Random Seed (Optional)", value=''),
73
+ gr.Slider(
74
+ minimum=0., maximum=8., value=3.5, label="Classifier-Free Guidance Strength"
75
+ )],
76
+ outputs="audio",
77
+ title="ConsistencyTTA: Accelerating Diffusion-Based Text-to-Audio " \
78
+ "Generation with Consistency Distillation",
79
+ description="This is the official demo page for <a href='https://consistency-tta.github." \
80
+ "io' target=&ldquo;blank&rdquo;>ConsistencyTTA</a>, a model that accelerates " \
81
+ "diffusion-based text-to-audio generation hundreds of times with consistency " \
82
+ "models. <br> Here, the audio is generated within a single non-autoregressive " \
83
+ "forward pass from the CLAP-finetuned ConsistencyTTA checkpoint. <br> Since " \
84
+ "the training dataset does not include speech, the model is not expected to " \
85
+ "generate coherent speech. <br> Have fun!"
86
+ )
87
+ iface.launch(share=True)
tango_diffusion_light.json ADDED
@@ -0,0 +1,46 @@
1
+ {
2
+ "_class_name": "UNet2DConditionModel",
3
+ "_diffusers_version": "0.10.0.dev0",
4
+ "act_fn": "silu",
5
+ "attention_head_dim": [
6
+ 5,
7
+ 10,
8
+ 20,
9
+ 20
10
+ ],
11
+ "block_out_channels": [
12
+ 256,
13
+ 512,
14
+ 1024,
15
+ 1024
16
+ ],
17
+ "center_input_sample": false,
18
+ "cross_attention_dim": 1024,
19
+ "down_block_types": [
20
+ "CrossAttnDownBlock2D",
21
+ "CrossAttnDownBlock2D",
22
+ "CrossAttnDownBlock2D",
23
+ "DownBlock2D"
24
+ ],
25
+ "downsample_padding": 1,
26
+ "dual_cross_attention": false,
27
+ "flip_sin_to_cos": true,
28
+ "freq_shift": 0,
29
+ "in_channels": 8,
30
+ "layers_per_block": 2,
31
+ "mid_block_scale_factor": 1,
32
+ "norm_eps": 1e-05,
33
+ "norm_num_groups": 32,
34
+ "num_class_embeds": null,
35
+ "only_cross_attention": false,
36
+ "out_channels": 8,
37
+ "sample_size": [32, 2],
38
+ "up_block_types": [
39
+ "UpBlock2D",
40
+ "CrossAttnUpBlock2D",
41
+ "CrossAttnUpBlock2D",
42
+ "CrossAttnUpBlock2D"
43
+ ],
44
+ "use_linear_projection": true,
45
+ "upcast_attention": true
46
+ }
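`tango_diffusion_light.json` describes a `UNet2DConditionModel` configuration: 8 latent channels, a 32×2 sample size, and 1024-dimensional cross-attention for the text conditioning. A minimal sketch of instantiating it with the vendored model class, assuming it keeps the upstream `from_config` behaviour (bookkeeping keys such as `_class_name` and `_diffusers_version` are ignored).

```python
import json

from diffusers.models.unet_2d_condition import UNet2DConditionModel  # vendored in this commit

with open("tango_diffusion_light.json") as f:
    config = json.load(f)

# Build the latent-space UNet from the JSON configuration.
unet = UNet2DConditionModel.from_config(config)
unet.eval()

n_params = sum(p.numel() for p in unet.parameters())
print(f"in/out channels: {config['in_channels']}/{config['out_channels']}, "
      f"cross-attention dim: {config['cross_attention_dim']}, parameters: {n_params / 1e6:.1f}M")
```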