Spaces:
Running
on
Zero
Running
on
Zero
Bai-YT
commited on
Commit
•
66982e9
0
Parent(s):
Gradio App for ConsistencyTTA V1
Browse files- .gitignore +6 -0
- README.md +7 -0
- audioldm/hifigan/__init__.py +7 -0
- audioldm/hifigan/models.py +127 -0
- audioldm/hifigan/utilities.py +88 -0
- audioldm/latent_diffusion/attention.py +469 -0
- audioldm/latent_diffusion/util.py +293 -0
- audioldm/stft.py +257 -0
- audioldm/utils.py +177 -0
- audioldm/variational_autoencoder/__init__.py +1 -0
- audioldm/variational_autoencoder/autoencoder.py +131 -0
- audioldm/variational_autoencoder/distributions.py +102 -0
- audioldm/variational_autoencoder/modules.py +1067 -0
- consistencytta.py +200 -0
- consistencytta_clapft_ckpt/.DS_Store +0 -0
- diffusers/__init__.py +2 -0
- diffusers/models/__init__.py +23 -0
- diffusers/models/activations.py +12 -0
- diffusers/models/attention.py +523 -0
- diffusers/models/attention_processor.py +1646 -0
- diffusers/models/dual_transformer_2d.py +151 -0
- diffusers/models/embeddings.py +480 -0
- diffusers/models/loaders.py +1481 -0
- diffusers/models/modeling_utils.py +978 -0
- diffusers/models/prior_transformer.py +194 -0
- diffusers/models/resnet.py +839 -0
- diffusers/models/transformer_2d.py +333 -0
- diffusers/models/unet_2d.py +315 -0
- diffusers/models/unet_2d_blocks.py +0 -0
- diffusers/models/unet_2d_condition.py +907 -0
- diffusers/models/unet_2d_condition_guided.py +945 -0
- diffusers/scheduling_heun_discrete.py +387 -0
- diffusers/utils/configuration_utils.py +647 -0
- diffusers/utils/constants.py +34 -0
- diffusers/utils/deprecation_utils.py +49 -0
- diffusers/utils/hub_utils.py +357 -0
- diffusers/utils/import_utils.py +649 -0
- diffusers/utils/logging.py +342 -0
- diffusers/utils/outputs.py +108 -0
- diffusers/utils/scheduling_utils.py +176 -0
- diffusers/utils/torch_utils.py +83 -0
- run_gradio.py +87 -0
- tango_diffusion_light.json +46 -0
.gitignore
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__
|
2 |
+
*/__pycache__
|
3 |
+
flagged
|
4 |
+
*.wav
|
5 |
+
*.pt
|
6 |
+
*.DS_Store
|
README.md
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Gradio App for ConsistencyTTA
|
2 |
+
|
3 |
+
Required packages:
|
4 |
+
`numpy scipy torch torchaudio einops soundfile librosa transformers gradio`
|
5 |
+
|
6 |
+
To run:
|
7 |
+
`python run_gradio.py`
|
audioldm/hifigan/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .models import Generator
|
2 |
+
|
3 |
+
|
4 |
+
class AttrDict(dict):
|
5 |
+
def __init__(self, *args, **kwargs):
|
6 |
+
super(AttrDict, self).__init__(*args, **kwargs)
|
7 |
+
self.__dict__ = self
|
audioldm/hifigan/models.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch.nn import Conv1d, ConvTranspose1d
|
5 |
+
from torch.nn.utils.parametrizations import weight_norm
|
6 |
+
from torch.nn.utils.parametrize import remove_parametrizations
|
7 |
+
|
8 |
+
|
9 |
+
LRELU_SLOPE = 0.1
|
10 |
+
|
11 |
+
|
12 |
+
def init_weights(m, mean=0.0, std=0.01):
|
13 |
+
classname = m.__class__.__name__
|
14 |
+
if classname.find("Conv") != -1:
|
15 |
+
m.weight.data.normal_(mean, std)
|
16 |
+
|
17 |
+
|
18 |
+
def get_padding(kernel_size, dilation=1):
|
19 |
+
return int((kernel_size * dilation - dilation) / 2)
|
20 |
+
|
21 |
+
|
22 |
+
class ResBlock(torch.nn.Module):
|
23 |
+
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
|
24 |
+
super(ResBlock, self).__init__()
|
25 |
+
self.h = h
|
26 |
+
self.convs1 = nn.ModuleList([
|
27 |
+
weight_norm(Conv1d(
|
28 |
+
channels, channels, kernel_size, 1, dilation=dilation[0],
|
29 |
+
padding=get_padding(kernel_size, dilation[0]),
|
30 |
+
)),
|
31 |
+
weight_norm(Conv1d(
|
32 |
+
channels, channels, kernel_size, 1, dilation=dilation[1],
|
33 |
+
padding=get_padding(kernel_size, dilation[1]),
|
34 |
+
)),
|
35 |
+
weight_norm(Conv1d(
|
36 |
+
channels, channels, kernel_size, 1, dilation=dilation[2],
|
37 |
+
padding=get_padding(kernel_size, dilation[2]),
|
38 |
+
)),
|
39 |
+
])
|
40 |
+
self.convs1.apply(init_weights)
|
41 |
+
|
42 |
+
self.convs2 = nn.ModuleList([
|
43 |
+
weight_norm(Conv1d(
|
44 |
+
channels, channels, kernel_size, 1, dilation=1,
|
45 |
+
padding=get_padding(kernel_size, 1),
|
46 |
+
)),
|
47 |
+
weight_norm(Conv1d(
|
48 |
+
channels, channels, kernel_size, 1, dilation=1,
|
49 |
+
padding=get_padding(kernel_size, 1),
|
50 |
+
)),
|
51 |
+
weight_norm(Conv1d(
|
52 |
+
channels, channels, kernel_size, 1, dilation=1,
|
53 |
+
padding=get_padding(kernel_size, 1),
|
54 |
+
)),
|
55 |
+
])
|
56 |
+
self.convs2.apply(init_weights)
|
57 |
+
|
58 |
+
def forward(self, x):
|
59 |
+
for c1, c2 in zip(self.convs1, self.convs2):
|
60 |
+
xt = F.leaky_relu(x, LRELU_SLOPE)
|
61 |
+
xt = c1(xt)
|
62 |
+
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
63 |
+
xt = c2(xt)
|
64 |
+
x = xt + x
|
65 |
+
return x
|
66 |
+
|
67 |
+
def remove_weight_norm(self):
|
68 |
+
for l in self.convs1:
|
69 |
+
remove_parametrizations(l, 'weight')
|
70 |
+
for l in self.convs2:
|
71 |
+
remove_parametrizations(l, 'weight')
|
72 |
+
|
73 |
+
|
74 |
+
class Generator(torch.nn.Module):
|
75 |
+
def __init__(self, h):
|
76 |
+
super(Generator, self).__init__()
|
77 |
+
self.h = h
|
78 |
+
self.num_kernels = len(h.resblock_kernel_sizes)
|
79 |
+
self.num_upsamples = len(h.upsample_rates)
|
80 |
+
self.conv_pre = weight_norm(
|
81 |
+
Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
|
82 |
+
)
|
83 |
+
resblock = ResBlock
|
84 |
+
|
85 |
+
self.ups = nn.ModuleList()
|
86 |
+
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
87 |
+
self.ups.append(weight_norm(ConvTranspose1d(
|
88 |
+
h.upsample_initial_channel // (2**i),
|
89 |
+
h.upsample_initial_channel // (2 ** (i + 1)),
|
90 |
+
k, u, padding=(k - u) // 2,
|
91 |
+
)))
|
92 |
+
|
93 |
+
self.resblocks = nn.ModuleList()
|
94 |
+
for i in range(len(self.ups)):
|
95 |
+
ch = h.upsample_initial_channel // (2 ** (i + 1))
|
96 |
+
for k, d in zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes):
|
97 |
+
self.resblocks.append(resblock(h, ch, k, d))
|
98 |
+
|
99 |
+
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
|
100 |
+
self.ups.apply(init_weights)
|
101 |
+
self.conv_post.apply(init_weights)
|
102 |
+
|
103 |
+
def forward(self, x):
|
104 |
+
x = self.conv_pre(x)
|
105 |
+
for i in range(self.num_upsamples):
|
106 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
107 |
+
x = self.ups[i](x)
|
108 |
+
xs = None
|
109 |
+
for j in range(self.num_kernels):
|
110 |
+
if xs is None:
|
111 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
112 |
+
else:
|
113 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
114 |
+
x = xs / self.num_kernels
|
115 |
+
x = F.leaky_relu(x)
|
116 |
+
x = self.conv_post(x)
|
117 |
+
x = torch.tanh(x)
|
118 |
+
|
119 |
+
return x
|
120 |
+
|
121 |
+
def remove_weight_norm(self):
|
122 |
+
for l in self.ups:
|
123 |
+
remove_parametrizations(l, 'weight')
|
124 |
+
for l in self.resblocks:
|
125 |
+
l.remove_weight_norm()
|
126 |
+
remove_parametrizations(self.conv_pre, 'weight')
|
127 |
+
remove_parametrizations(self.conv_post, 'weight')
|
audioldm/hifigan/utilities.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
import audioldm.hifigan as hifigan
|
4 |
+
|
5 |
+
|
6 |
+
HIFIGAN_16K_64 = {
|
7 |
+
"resblock": "1",
|
8 |
+
"num_gpus": 6,
|
9 |
+
"batch_size": 16,
|
10 |
+
"learning_rate": 0.0002,
|
11 |
+
"adam_b1": 0.8,
|
12 |
+
"adam_b2": 0.99,
|
13 |
+
"lr_decay": 0.999,
|
14 |
+
"seed": 1234,
|
15 |
+
"upsample_rates": [5, 4, 2, 2, 2],
|
16 |
+
"upsample_kernel_sizes": [16, 16, 8, 4, 4],
|
17 |
+
"upsample_initial_channel": 1024,
|
18 |
+
"resblock_kernel_sizes": [3, 7, 11],
|
19 |
+
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
20 |
+
"segment_size": 8192,
|
21 |
+
"num_mels": 64,
|
22 |
+
"num_freq": 1025,
|
23 |
+
"n_fft": 1024,
|
24 |
+
"hop_size": 160,
|
25 |
+
"win_size": 1024,
|
26 |
+
"sampling_rate": 16000,
|
27 |
+
"fmin": 0,
|
28 |
+
"fmax": 8000,
|
29 |
+
"fmax_for_loss": None,
|
30 |
+
"num_workers": 4,
|
31 |
+
"dist_config": {
|
32 |
+
"dist_backend": "nccl",
|
33 |
+
"dist_url": "tcp://localhost:54321",
|
34 |
+
"world_size": 1,
|
35 |
+
},
|
36 |
+
}
|
37 |
+
|
38 |
+
|
39 |
+
def get_available_checkpoint_keys(model, ckpt):
|
40 |
+
print("==> Attemp to reload from %s" % ckpt)
|
41 |
+
state_dict = torch.load(ckpt)["state_dict"]
|
42 |
+
current_state_dict = model.state_dict()
|
43 |
+
new_state_dict = {}
|
44 |
+
for k in state_dict.keys():
|
45 |
+
if (
|
46 |
+
k in current_state_dict.keys()
|
47 |
+
and current_state_dict[k].size() == state_dict[k].size()
|
48 |
+
):
|
49 |
+
new_state_dict[k] = state_dict[k]
|
50 |
+
else:
|
51 |
+
print("==> WARNING: Skipping %s" % k)
|
52 |
+
print(
|
53 |
+
"%s out of %s keys are matched"
|
54 |
+
% (len(new_state_dict.keys()), len(state_dict.keys()))
|
55 |
+
)
|
56 |
+
return new_state_dict
|
57 |
+
|
58 |
+
|
59 |
+
def get_param_num(model):
|
60 |
+
num_param = sum(param.numel() for param in model.parameters())
|
61 |
+
return num_param
|
62 |
+
|
63 |
+
|
64 |
+
def get_vocoder(config, device):
|
65 |
+
config = hifigan.AttrDict(HIFIGAN_16K_64)
|
66 |
+
vocoder = hifigan.Generator(config)
|
67 |
+
vocoder.eval()
|
68 |
+
vocoder.remove_weight_norm()
|
69 |
+
vocoder.to(device)
|
70 |
+
return vocoder
|
71 |
+
|
72 |
+
|
73 |
+
def vocoder_infer(mels, vocoder, allow_grad=False, lengths=None):
|
74 |
+
vocoder.eval()
|
75 |
+
|
76 |
+
if allow_grad:
|
77 |
+
wavs = vocoder(mels).squeeze(1).float()
|
78 |
+
wavs = wavs - (wavs.max() + wavs.min()) / 2
|
79 |
+
else:
|
80 |
+
with torch.no_grad():
|
81 |
+
wavs = vocoder(mels).squeeze(1).float()
|
82 |
+
wavs = wavs - (wavs.max() + wavs.min()) / 2
|
83 |
+
wavs = (wavs.cpu().numpy() * 32768).astype("int16")
|
84 |
+
|
85 |
+
if lengths is not None:
|
86 |
+
wavs = wavs[:, :lengths]
|
87 |
+
|
88 |
+
return wavs
|
audioldm/latent_diffusion/attention.py
ADDED
@@ -0,0 +1,469 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from inspect import isfunction
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch import nn
|
6 |
+
from einops import rearrange
|
7 |
+
|
8 |
+
from audioldm.latent_diffusion.util import checkpoint
|
9 |
+
|
10 |
+
|
11 |
+
def exists(val):
|
12 |
+
return val is not None
|
13 |
+
|
14 |
+
|
15 |
+
def uniq(arr):
|
16 |
+
return {el: True for el in arr}.keys()
|
17 |
+
|
18 |
+
|
19 |
+
def default(val, d):
|
20 |
+
if exists(val):
|
21 |
+
return val
|
22 |
+
return d() if isfunction(d) else d
|
23 |
+
|
24 |
+
|
25 |
+
def max_neg_value(t):
|
26 |
+
return -torch.finfo(t.dtype).max
|
27 |
+
|
28 |
+
|
29 |
+
def init_(tensor):
|
30 |
+
dim = tensor.shape[-1]
|
31 |
+
std = 1 / math.sqrt(dim)
|
32 |
+
tensor.uniform_(-std, std)
|
33 |
+
return tensor
|
34 |
+
|
35 |
+
|
36 |
+
# feedforward
|
37 |
+
class GEGLU(nn.Module):
|
38 |
+
def __init__(self, dim_in, dim_out):
|
39 |
+
super().__init__()
|
40 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
41 |
+
|
42 |
+
def forward(self, x):
|
43 |
+
x, gate = self.proj(x).chunk(2, dim=-1)
|
44 |
+
return x * F.gelu(gate)
|
45 |
+
|
46 |
+
|
47 |
+
class FeedForward(nn.Module):
|
48 |
+
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
|
49 |
+
super().__init__()
|
50 |
+
inner_dim = int(dim * mult)
|
51 |
+
dim_out = default(dim_out, dim)
|
52 |
+
project_in = (
|
53 |
+
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
|
54 |
+
if not glu
|
55 |
+
else GEGLU(dim, inner_dim)
|
56 |
+
)
|
57 |
+
|
58 |
+
self.net = nn.Sequential(
|
59 |
+
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
|
60 |
+
)
|
61 |
+
|
62 |
+
def forward(self, x):
|
63 |
+
return self.net(x)
|
64 |
+
|
65 |
+
|
66 |
+
def zero_module(module):
|
67 |
+
"""
|
68 |
+
Zero out the parameters of a module and return it.
|
69 |
+
"""
|
70 |
+
for p in module.parameters():
|
71 |
+
p.detach().zero_()
|
72 |
+
return module
|
73 |
+
|
74 |
+
|
75 |
+
def Normalize(in_channels):
|
76 |
+
return torch.nn.GroupNorm(
|
77 |
+
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
78 |
+
)
|
79 |
+
|
80 |
+
|
81 |
+
class LinearAttention(nn.Module):
|
82 |
+
def __init__(self, dim, heads=4, dim_head=32):
|
83 |
+
super().__init__()
|
84 |
+
self.heads = heads
|
85 |
+
hidden_dim = dim_head * heads
|
86 |
+
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
87 |
+
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
88 |
+
|
89 |
+
def forward(self, x):
|
90 |
+
b, c, h, w = x.shape
|
91 |
+
qkv = self.to_qkv(x)
|
92 |
+
q, k, v = rearrange(
|
93 |
+
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
|
94 |
+
)
|
95 |
+
k = k.softmax(dim=-1)
|
96 |
+
context = torch.einsum("bhdn,bhen->bhde", k, v)
|
97 |
+
out = torch.einsum("bhde,bhdn->bhen", context, q)
|
98 |
+
out = rearrange(
|
99 |
+
out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
|
100 |
+
)
|
101 |
+
return self.to_out(out)
|
102 |
+
|
103 |
+
|
104 |
+
class SpatialSelfAttention(nn.Module):
|
105 |
+
def __init__(self, in_channels):
|
106 |
+
super().__init__()
|
107 |
+
self.in_channels = in_channels
|
108 |
+
|
109 |
+
self.norm = Normalize(in_channels)
|
110 |
+
self.q = torch.nn.Conv2d(
|
111 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
112 |
+
)
|
113 |
+
self.k = torch.nn.Conv2d(
|
114 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
115 |
+
)
|
116 |
+
self.v = torch.nn.Conv2d(
|
117 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
118 |
+
)
|
119 |
+
self.proj_out = torch.nn.Conv2d(
|
120 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
121 |
+
)
|
122 |
+
|
123 |
+
def forward(self, x):
|
124 |
+
h_ = x
|
125 |
+
h_ = self.norm(h_)
|
126 |
+
q = self.q(h_)
|
127 |
+
k = self.k(h_)
|
128 |
+
v = self.v(h_)
|
129 |
+
|
130 |
+
# compute attention
|
131 |
+
b, c, h, w = q.shape
|
132 |
+
q = rearrange(q, "b c h w -> b (h w) c")
|
133 |
+
k = rearrange(k, "b c h w -> b c (h w)")
|
134 |
+
w_ = torch.einsum("bij,bjk->bik", q, k)
|
135 |
+
|
136 |
+
w_ = w_ * (int(c) ** (-0.5))
|
137 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
138 |
+
|
139 |
+
# attend to values
|
140 |
+
v = rearrange(v, "b c h w -> b c (h w)")
|
141 |
+
w_ = rearrange(w_, "b i j -> b j i")
|
142 |
+
h_ = torch.einsum("bij,bjk->bik", v, w_)
|
143 |
+
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
|
144 |
+
h_ = self.proj_out(h_)
|
145 |
+
|
146 |
+
return x + h_
|
147 |
+
|
148 |
+
|
149 |
+
class CrossAttention(nn.Module):
|
150 |
+
"""
|
151 |
+
### Cross Attention Layer
|
152 |
+
This falls-back to self-attention when conditional embeddings are not specified.
|
153 |
+
"""
|
154 |
+
|
155 |
+
# use_flash_attention: bool = True
|
156 |
+
use_flash_attention: bool = False
|
157 |
+
|
158 |
+
def __init__(
|
159 |
+
self,
|
160 |
+
query_dim,
|
161 |
+
context_dim=None,
|
162 |
+
heads=8,
|
163 |
+
dim_head=64,
|
164 |
+
dropout=0.0,
|
165 |
+
is_inplace: bool = True,
|
166 |
+
):
|
167 |
+
# def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True):
|
168 |
+
"""
|
169 |
+
:param d_model: is the input embedding size
|
170 |
+
:param n_heads: is the number of attention heads
|
171 |
+
:param d_head: is the size of a attention head
|
172 |
+
:param d_cond: is the size of the conditional embeddings
|
173 |
+
:param is_inplace: specifies whether to perform the attention softmax computation inplace to
|
174 |
+
save memory
|
175 |
+
"""
|
176 |
+
super().__init__()
|
177 |
+
|
178 |
+
self.is_inplace = is_inplace
|
179 |
+
self.n_heads = heads
|
180 |
+
self.d_head = dim_head
|
181 |
+
|
182 |
+
# Attention scaling factor
|
183 |
+
self.scale = dim_head**-0.5
|
184 |
+
|
185 |
+
# The normal self-attention layer
|
186 |
+
if context_dim is None:
|
187 |
+
context_dim = query_dim
|
188 |
+
|
189 |
+
# Query, key and value mappings
|
190 |
+
d_attn = dim_head * heads
|
191 |
+
self.to_q = nn.Linear(query_dim, d_attn, bias=False)
|
192 |
+
self.to_k = nn.Linear(context_dim, d_attn, bias=False)
|
193 |
+
self.to_v = nn.Linear(context_dim, d_attn, bias=False)
|
194 |
+
|
195 |
+
# Final linear layer
|
196 |
+
self.to_out = nn.Sequential(nn.Linear(d_attn, query_dim), nn.Dropout(dropout))
|
197 |
+
|
198 |
+
# Setup [flash attention](https://github.com/HazyResearch/flash-attention).
|
199 |
+
# Flash attention is only used if it's installed
|
200 |
+
# and `CrossAttention.use_flash_attention` is set to `True`.
|
201 |
+
try:
|
202 |
+
# You can install flash attention by cloning their Github repo,
|
203 |
+
# [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)
|
204 |
+
# and then running `python setup.py install`
|
205 |
+
from flash_attn.flash_attention import FlashAttention
|
206 |
+
|
207 |
+
self.flash = FlashAttention()
|
208 |
+
# Set the scale for scaled dot-product attention.
|
209 |
+
self.flash.softmax_scale = self.scale
|
210 |
+
# Set to `None` if it's not installed
|
211 |
+
except ImportError:
|
212 |
+
self.flash = None
|
213 |
+
|
214 |
+
def forward(self, x, context=None, mask=None):
|
215 |
+
"""
|
216 |
+
:param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
|
217 |
+
:param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
|
218 |
+
"""
|
219 |
+
|
220 |
+
# If `cond` is `None` we perform self attention
|
221 |
+
has_cond = context is not None
|
222 |
+
if not has_cond:
|
223 |
+
context = x
|
224 |
+
|
225 |
+
# Get query, key and value vectors
|
226 |
+
q = self.to_q(x)
|
227 |
+
k = self.to_k(context)
|
228 |
+
v = self.to_v(context)
|
229 |
+
|
230 |
+
# Use flash attention if it's available and the head size is less than or equal to `128`
|
231 |
+
if (
|
232 |
+
CrossAttention.use_flash_attention
|
233 |
+
and self.flash is not None
|
234 |
+
and not has_cond
|
235 |
+
and self.d_head <= 128
|
236 |
+
):
|
237 |
+
return self.flash_attention(q, k, v)
|
238 |
+
# Otherwise, fallback to normal attention
|
239 |
+
else:
|
240 |
+
return self.normal_attention(q, k, v)
|
241 |
+
|
242 |
+
def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
243 |
+
"""
|
244 |
+
#### Flash Attention
|
245 |
+
:param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
|
246 |
+
:param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
|
247 |
+
:param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
|
248 |
+
"""
|
249 |
+
|
250 |
+
# Get batch size and number of elements along sequence axis (`width * height`)
|
251 |
+
batch_size, seq_len, _ = q.shape
|
252 |
+
|
253 |
+
# Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of
|
254 |
+
# shape `[batch_size, seq_len, 3, n_heads * d_head]`
|
255 |
+
qkv = torch.stack((q, k, v), dim=2)
|
256 |
+
# Split the heads
|
257 |
+
qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)
|
258 |
+
|
259 |
+
# Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to
|
260 |
+
# fit this size.
|
261 |
+
if self.d_head <= 32:
|
262 |
+
pad = 32 - self.d_head
|
263 |
+
elif self.d_head <= 64:
|
264 |
+
pad = 64 - self.d_head
|
265 |
+
elif self.d_head <= 128:
|
266 |
+
pad = 128 - self.d_head
|
267 |
+
else:
|
268 |
+
raise ValueError(f"Head size ${self.d_head} too large for Flash Attention")
|
269 |
+
|
270 |
+
# Pad the heads
|
271 |
+
if pad:
|
272 |
+
qkv = torch.cat(
|
273 |
+
(qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1
|
274 |
+
)
|
275 |
+
|
276 |
+
# Compute attention
|
277 |
+
# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
|
278 |
+
# This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]`
|
279 |
+
# TODO here I add the dtype changing
|
280 |
+
out, _ = self.flash(qkv.type(torch.float16))
|
281 |
+
# Truncate the extra head size
|
282 |
+
out = out[:, :, :, : self.d_head].float()
|
283 |
+
# Reshape to `[batch_size, seq_len, n_heads * d_head]`
|
284 |
+
out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)
|
285 |
+
|
286 |
+
# Map to `[batch_size, height * width, d_model]` with a linear layer
|
287 |
+
return self.to_out(out)
|
288 |
+
|
289 |
+
def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
290 |
+
"""
|
291 |
+
#### Normal Attention
|
292 |
+
|
293 |
+
:param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
|
294 |
+
:param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
|
295 |
+
:param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
|
296 |
+
"""
|
297 |
+
|
298 |
+
# Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
|
299 |
+
q = q.view(*q.shape[:2], self.n_heads, -1) # [bs, 64, 20, 32]
|
300 |
+
k = k.view(*k.shape[:2], self.n_heads, -1) # [bs, 1, 20, 32]
|
301 |
+
v = v.view(*v.shape[:2], self.n_heads, -1)
|
302 |
+
|
303 |
+
# Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
|
304 |
+
attn = torch.einsum("bihd,bjhd->bhij", q, k) * self.scale
|
305 |
+
|
306 |
+
# Compute softmax
|
307 |
+
# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
|
308 |
+
if self.is_inplace:
|
309 |
+
half = attn.shape[0] // 2
|
310 |
+
attn[half:] = attn[half:].softmax(dim=-1)
|
311 |
+
attn[:half] = attn[:half].softmax(dim=-1)
|
312 |
+
else:
|
313 |
+
attn = attn.softmax(dim=-1)
|
314 |
+
|
315 |
+
# Compute attention output
|
316 |
+
# $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
|
317 |
+
# attn: [bs, 20, 64, 1]
|
318 |
+
# v: [bs, 1, 20, 32]
|
319 |
+
out = torch.einsum("bhij,bjhd->bihd", attn, v)
|
320 |
+
# Reshape to `[batch_size, height * width, n_heads * d_head]`
|
321 |
+
out = out.reshape(*out.shape[:2], -1)
|
322 |
+
# Map to `[batch_size, height * width, d_model]` with a linear layer
|
323 |
+
return self.to_out(out)
|
324 |
+
|
325 |
+
|
326 |
+
# class CrossAttention(nn.Module):
|
327 |
+
# def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
328 |
+
# super().__init__()
|
329 |
+
# inner_dim = dim_head * heads
|
330 |
+
# context_dim = default(context_dim, query_dim)
|
331 |
+
|
332 |
+
# self.scale = dim_head ** -0.5
|
333 |
+
# self.heads = heads
|
334 |
+
|
335 |
+
# self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
336 |
+
# self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
337 |
+
# self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
338 |
+
|
339 |
+
# self.to_out = nn.Sequential(
|
340 |
+
# nn.Linear(inner_dim, query_dim),
|
341 |
+
# nn.Dropout(dropout)
|
342 |
+
# )
|
343 |
+
|
344 |
+
# def forward(self, x, context=None, mask=None):
|
345 |
+
# h = self.heads
|
346 |
+
|
347 |
+
# q = self.to_q(x)
|
348 |
+
# context = default(context, x)
|
349 |
+
# k = self.to_k(context)
|
350 |
+
# v = self.to_v(context)
|
351 |
+
|
352 |
+
# q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
353 |
+
|
354 |
+
# sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
355 |
+
|
356 |
+
# if exists(mask):
|
357 |
+
# mask = rearrange(mask, 'b ... -> b (...)')
|
358 |
+
# max_neg_value = -torch.finfo(sim.dtype).max
|
359 |
+
# mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
360 |
+
# sim.masked_fill_(~mask, max_neg_value)
|
361 |
+
|
362 |
+
# # attention, what we cannot get enough of
|
363 |
+
# attn = sim.softmax(dim=-1)
|
364 |
+
|
365 |
+
# out = einsum('b i j, b j d -> b i d', attn, v)
|
366 |
+
# out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
367 |
+
# return self.to_out(out)
|
368 |
+
|
369 |
+
|
370 |
+
class BasicTransformerBlock(nn.Module):
|
371 |
+
def __init__(
|
372 |
+
self,
|
373 |
+
dim,
|
374 |
+
n_heads,
|
375 |
+
d_head,
|
376 |
+
dropout=0.0,
|
377 |
+
context_dim=None,
|
378 |
+
gated_ff=True,
|
379 |
+
checkpoint=True,
|
380 |
+
):
|
381 |
+
super().__init__()
|
382 |
+
self.attn1 = CrossAttention(
|
383 |
+
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
|
384 |
+
) # is a self-attention
|
385 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
386 |
+
self.attn2 = CrossAttention(
|
387 |
+
query_dim=dim,
|
388 |
+
context_dim=context_dim,
|
389 |
+
heads=n_heads,
|
390 |
+
dim_head=d_head,
|
391 |
+
dropout=dropout,
|
392 |
+
) # is self-attn if context is none
|
393 |
+
self.norm1 = nn.LayerNorm(dim)
|
394 |
+
self.norm2 = nn.LayerNorm(dim)
|
395 |
+
self.norm3 = nn.LayerNorm(dim)
|
396 |
+
self.checkpoint = checkpoint
|
397 |
+
|
398 |
+
def forward(self, x, context=None):
|
399 |
+
if context is None:
|
400 |
+
return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint)
|
401 |
+
else:
|
402 |
+
return checkpoint(
|
403 |
+
self._forward, (x, context), self.parameters(), self.checkpoint
|
404 |
+
)
|
405 |
+
|
406 |
+
def _forward(self, x, context=None):
|
407 |
+
x = self.attn1(self.norm1(x)) + x
|
408 |
+
x = self.attn2(self.norm2(x), context=context) + x
|
409 |
+
x = self.ff(self.norm3(x)) + x
|
410 |
+
return x
|
411 |
+
|
412 |
+
|
413 |
+
class SpatialTransformer(nn.Module):
|
414 |
+
"""
|
415 |
+
Transformer block for image-like data.
|
416 |
+
First, project the input (aka embedding)
|
417 |
+
and reshape to b, t, d.
|
418 |
+
Then apply standard transformer action.
|
419 |
+
Finally, reshape to image
|
420 |
+
"""
|
421 |
+
|
422 |
+
def __init__(
|
423 |
+
self,
|
424 |
+
in_channels,
|
425 |
+
n_heads,
|
426 |
+
d_head,
|
427 |
+
depth=1,
|
428 |
+
dropout=0.0,
|
429 |
+
context_dim=None,
|
430 |
+
no_context=False,
|
431 |
+
):
|
432 |
+
super().__init__()
|
433 |
+
|
434 |
+
if no_context:
|
435 |
+
context_dim = None
|
436 |
+
|
437 |
+
self.in_channels = in_channels
|
438 |
+
inner_dim = n_heads * d_head
|
439 |
+
self.norm = Normalize(in_channels)
|
440 |
+
|
441 |
+
self.proj_in = nn.Conv2d(
|
442 |
+
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
|
443 |
+
)
|
444 |
+
|
445 |
+
self.transformer_blocks = nn.ModuleList(
|
446 |
+
[
|
447 |
+
BasicTransformerBlock(
|
448 |
+
inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim
|
449 |
+
)
|
450 |
+
for d in range(depth)
|
451 |
+
]
|
452 |
+
)
|
453 |
+
|
454 |
+
self.proj_out = zero_module(
|
455 |
+
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
456 |
+
)
|
457 |
+
|
458 |
+
def forward(self, x, context=None):
|
459 |
+
# note: if no context is given, cross-attention defaults to self-attention
|
460 |
+
b, c, h, w = x.shape
|
461 |
+
x_in = x
|
462 |
+
x = self.norm(x)
|
463 |
+
x = self.proj_in(x)
|
464 |
+
x = rearrange(x, "b c h w -> b (h w) c")
|
465 |
+
for block in self.transformer_blocks:
|
466 |
+
x = block(x, context=context)
|
467 |
+
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
|
468 |
+
x = self.proj_out(x)
|
469 |
+
return x + x_in
|
audioldm/latent_diffusion/util.py
ADDED
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# adopted from
|
2 |
+
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
3 |
+
# and
|
4 |
+
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
5 |
+
# and
|
6 |
+
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
|
7 |
+
#
|
8 |
+
# thanks!
|
9 |
+
|
10 |
+
import math
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import numpy as np
|
14 |
+
from einops import repeat
|
15 |
+
|
16 |
+
from audioldm.utils import instantiate_from_config
|
17 |
+
|
18 |
+
|
19 |
+
def make_beta_schedule(
|
20 |
+
schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
|
21 |
+
):
|
22 |
+
if schedule == "linear":
|
23 |
+
betas = (
|
24 |
+
torch.linspace(
|
25 |
+
linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
|
26 |
+
)
|
27 |
+
** 2
|
28 |
+
)
|
29 |
+
|
30 |
+
elif schedule == "cosine":
|
31 |
+
timesteps = (
|
32 |
+
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
|
33 |
+
)
|
34 |
+
alphas = timesteps / (1 + cosine_s) * np.pi / 2
|
35 |
+
alphas = torch.cos(alphas).pow(2)
|
36 |
+
alphas = alphas / alphas[0]
|
37 |
+
betas = 1 - alphas[1:] / alphas[:-1]
|
38 |
+
betas = np.clip(betas, a_min=0, a_max=0.999)
|
39 |
+
|
40 |
+
elif schedule == "sqrt_linear":
|
41 |
+
betas = torch.linspace(
|
42 |
+
linear_start, linear_end, n_timestep, dtype=torch.float64
|
43 |
+
)
|
44 |
+
elif schedule == "sqrt":
|
45 |
+
betas = (
|
46 |
+
torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
|
47 |
+
** 0.5
|
48 |
+
)
|
49 |
+
else:
|
50 |
+
raise ValueError(f"schedule '{schedule}' unknown.")
|
51 |
+
return betas.numpy()
|
52 |
+
|
53 |
+
|
54 |
+
def make_ddim_timesteps(
|
55 |
+
ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
|
56 |
+
):
|
57 |
+
if ddim_discr_method == "uniform":
|
58 |
+
c = num_ddpm_timesteps // num_ddim_timesteps
|
59 |
+
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
|
60 |
+
elif ddim_discr_method == "quad":
|
61 |
+
ddim_timesteps = (
|
62 |
+
(np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
|
63 |
+
).astype(int)
|
64 |
+
else:
|
65 |
+
raise NotImplementedError(
|
66 |
+
f'There is no ddim discretization method called "{ddim_discr_method}"'
|
67 |
+
)
|
68 |
+
|
69 |
+
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
|
70 |
+
# add one to get the final alpha values right (the ones from first scale to data during sampling)
|
71 |
+
steps_out = ddim_timesteps + 1
|
72 |
+
if verbose:
|
73 |
+
print(f"Selected timesteps for ddim sampler: {steps_out}")
|
74 |
+
return steps_out
|
75 |
+
|
76 |
+
|
77 |
+
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
|
78 |
+
# select alphas for computing the variance schedule
|
79 |
+
alphas = alphacums[ddim_timesteps]
|
80 |
+
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
|
81 |
+
|
82 |
+
# according the the formula provided in https://arxiv.org/abs/2010.02502
|
83 |
+
sigmas = eta * np.sqrt(
|
84 |
+
(1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
|
85 |
+
)
|
86 |
+
if verbose:
|
87 |
+
print(
|
88 |
+
f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
|
89 |
+
)
|
90 |
+
print(
|
91 |
+
f"For the chosen value of eta, which is {eta}, "
|
92 |
+
f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
|
93 |
+
)
|
94 |
+
return sigmas, alphas, alphas_prev
|
95 |
+
|
96 |
+
|
97 |
+
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
98 |
+
"""
|
99 |
+
Create a beta schedule that discretizes the given alpha_t_bar function,
|
100 |
+
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
101 |
+
:param num_diffusion_timesteps: the number of betas to produce.
|
102 |
+
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
103 |
+
produces the cumulative product of (1-beta) up to that
|
104 |
+
part of the diffusion process.
|
105 |
+
:param max_beta: the maximum beta to use; use values lower than 1 to
|
106 |
+
prevent singularities.
|
107 |
+
"""
|
108 |
+
betas = []
|
109 |
+
for i in range(num_diffusion_timesteps):
|
110 |
+
t1 = i / num_diffusion_timesteps
|
111 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
112 |
+
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
113 |
+
return np.array(betas)
|
114 |
+
|
115 |
+
|
116 |
+
def extract_into_tensor(a, t, x_shape):
|
117 |
+
b, *_ = t.shape
|
118 |
+
out = a.gather(-1, t).contiguous()
|
119 |
+
return out.reshape(b, *((1,) * (len(x_shape) - 1))).contiguous()
|
120 |
+
|
121 |
+
|
122 |
+
def checkpoint(func, inputs, params, flag):
|
123 |
+
"""
|
124 |
+
Evaluate a function without caching intermediate activations, allowing for
|
125 |
+
reduced memory at the expense of extra compute in the backward pass.
|
126 |
+
:param func: the function to evaluate.
|
127 |
+
:param inputs: the argument sequence to pass to `func`.
|
128 |
+
:param params: a sequence of parameters `func` depends on but does not
|
129 |
+
explicitly take as arguments.
|
130 |
+
:param flag: if False, disable gradient checkpointing.
|
131 |
+
"""
|
132 |
+
if flag:
|
133 |
+
args = tuple(inputs) + tuple(params)
|
134 |
+
return CheckpointFunction.apply(func, len(inputs), *args)
|
135 |
+
else:
|
136 |
+
return func(*inputs)
|
137 |
+
|
138 |
+
|
139 |
+
class CheckpointFunction(torch.autograd.Function):
|
140 |
+
@staticmethod
|
141 |
+
def forward(ctx, run_function, length, *args):
|
142 |
+
ctx.run_function = run_function
|
143 |
+
ctx.input_tensors = list(args[:length])
|
144 |
+
ctx.input_params = list(args[length:])
|
145 |
+
|
146 |
+
with torch.no_grad():
|
147 |
+
output_tensors = ctx.run_function(*ctx.input_tensors)
|
148 |
+
return output_tensors
|
149 |
+
|
150 |
+
@staticmethod
|
151 |
+
def backward(ctx, *output_grads):
|
152 |
+
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
|
153 |
+
with torch.enable_grad():
|
154 |
+
# Fixes a bug where the first op in run_function modifies the
|
155 |
+
# Tensor storage in place, which is not allowed for detach()'d
|
156 |
+
# Tensors.
|
157 |
+
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
|
158 |
+
output_tensors = ctx.run_function(*shallow_copies)
|
159 |
+
input_grads = torch.autograd.grad(
|
160 |
+
output_tensors,
|
161 |
+
ctx.input_tensors + ctx.input_params,
|
162 |
+
output_grads,
|
163 |
+
allow_unused=True,
|
164 |
+
)
|
165 |
+
del ctx.input_tensors
|
166 |
+
del ctx.input_params
|
167 |
+
del output_tensors
|
168 |
+
return (None, None) + input_grads
|
169 |
+
|
170 |
+
|
171 |
+
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
|
172 |
+
"""
|
173 |
+
Create sinusoidal timestep embeddings.
|
174 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
175 |
+
These may be fractional.
|
176 |
+
:param dim: the dimension of the output.
|
177 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
178 |
+
:return: an [N x dim] Tensor of positional embeddings.
|
179 |
+
"""
|
180 |
+
if not repeat_only:
|
181 |
+
half = dim // 2
|
182 |
+
freqs = torch.exp(
|
183 |
+
-math.log(max_period)
|
184 |
+
* torch.arange(start=0, end=half, dtype=torch.float32)
|
185 |
+
/ half
|
186 |
+
).to(device=timesteps.device)
|
187 |
+
args = timesteps[:, None].float() * freqs[None]
|
188 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
189 |
+
if dim % 2:
|
190 |
+
embedding = torch.cat(
|
191 |
+
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
192 |
+
)
|
193 |
+
else:
|
194 |
+
embedding = repeat(timesteps, "b -> b d", d=dim)
|
195 |
+
return embedding
|
196 |
+
|
197 |
+
|
198 |
+
def zero_module(module):
|
199 |
+
"""
|
200 |
+
Zero out the parameters of a module and return it.
|
201 |
+
"""
|
202 |
+
for p in module.parameters():
|
203 |
+
p.detach().zero_()
|
204 |
+
return module
|
205 |
+
|
206 |
+
|
207 |
+
def scale_module(module, scale):
|
208 |
+
"""
|
209 |
+
Scale the parameters of a module and return it.
|
210 |
+
"""
|
211 |
+
for p in module.parameters():
|
212 |
+
p.detach().mul_(scale)
|
213 |
+
return module
|
214 |
+
|
215 |
+
|
216 |
+
def mean_flat(tensor):
|
217 |
+
"""
|
218 |
+
Take the mean over all non-batch dimensions.
|
219 |
+
"""
|
220 |
+
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
221 |
+
|
222 |
+
|
223 |
+
def normalization(channels):
|
224 |
+
"""
|
225 |
+
Make a standard normalization layer.
|
226 |
+
:param channels: number of input channels.
|
227 |
+
:return: an nn.Module for normalization.
|
228 |
+
"""
|
229 |
+
return GroupNorm32(32, channels)
|
230 |
+
|
231 |
+
|
232 |
+
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
233 |
+
class SiLU(nn.Module):
|
234 |
+
def forward(self, x):
|
235 |
+
return x * torch.sigmoid(x)
|
236 |
+
|
237 |
+
|
238 |
+
class GroupNorm32(nn.GroupNorm):
|
239 |
+
def forward(self, x):
|
240 |
+
return super().forward(x.float()).type(x.dtype)
|
241 |
+
|
242 |
+
|
243 |
+
def conv_nd(dims, *args, **kwargs):
|
244 |
+
"""
|
245 |
+
Create a 1D, 2D, or 3D convolution module.
|
246 |
+
"""
|
247 |
+
if dims == 1:
|
248 |
+
return nn.Conv1d(*args, **kwargs)
|
249 |
+
elif dims == 2:
|
250 |
+
return nn.Conv2d(*args, **kwargs)
|
251 |
+
elif dims == 3:
|
252 |
+
return nn.Conv3d(*args, **kwargs)
|
253 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
254 |
+
|
255 |
+
|
256 |
+
def linear(*args, **kwargs):
|
257 |
+
"""
|
258 |
+
Create a linear module.
|
259 |
+
"""
|
260 |
+
return nn.Linear(*args, **kwargs)
|
261 |
+
|
262 |
+
|
263 |
+
def avg_pool_nd(dims, *args, **kwargs):
|
264 |
+
"""
|
265 |
+
Create a 1D, 2D, or 3D average pooling module.
|
266 |
+
"""
|
267 |
+
if dims == 1:
|
268 |
+
return nn.AvgPool1d(*args, **kwargs)
|
269 |
+
elif dims == 2:
|
270 |
+
return nn.AvgPool2d(*args, **kwargs)
|
271 |
+
elif dims == 3:
|
272 |
+
return nn.AvgPool3d(*args, **kwargs)
|
273 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
274 |
+
|
275 |
+
|
276 |
+
class HybridConditioner(nn.Module):
|
277 |
+
def __init__(self, c_concat_config, c_crossattn_config):
|
278 |
+
super().__init__()
|
279 |
+
self.concat_conditioner = instantiate_from_config(c_concat_config)
|
280 |
+
self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
|
281 |
+
|
282 |
+
def forward(self, c_concat, c_crossattn):
|
283 |
+
c_concat = self.concat_conditioner(c_concat)
|
284 |
+
c_crossattn = self.crossattn_conditioner(c_crossattn)
|
285 |
+
return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}
|
286 |
+
|
287 |
+
|
288 |
+
def noise_like(shape, device, repeat=False):
|
289 |
+
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
|
290 |
+
shape[0], *((1,) * (len(shape) - 1))
|
291 |
+
)
|
292 |
+
noise = lambda: torch.randn(shape, device=device)
|
293 |
+
return repeat_noise() if repeat else noise()
|
audioldm/stft.py
ADDED
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import numpy as np
|
4 |
+
from scipy.signal import get_window
|
5 |
+
from librosa.util import pad_center, tiny, normalize, pad_center
|
6 |
+
from librosa.filters import mel as librosa_mel_fn
|
7 |
+
|
8 |
+
|
9 |
+
def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
|
10 |
+
"""
|
11 |
+
Parameters
|
12 |
+
----------
|
13 |
+
C: compression factor
|
14 |
+
"""
|
15 |
+
return normalize_fun(torch.clamp(x, min=clip_val) * C)
|
16 |
+
|
17 |
+
|
18 |
+
def dynamic_range_decompression(x, C=1):
|
19 |
+
"""
|
20 |
+
Parameters
|
21 |
+
----------
|
22 |
+
C: compression factor used to compress
|
23 |
+
"""
|
24 |
+
return torch.exp(x) / C
|
25 |
+
|
26 |
+
|
27 |
+
def window_sumsquare(
|
28 |
+
window,
|
29 |
+
n_frames,
|
30 |
+
hop_length,
|
31 |
+
win_length,
|
32 |
+
n_fft,
|
33 |
+
dtype=np.float32,
|
34 |
+
norm=None,
|
35 |
+
):
|
36 |
+
"""
|
37 |
+
# from librosa 0.6
|
38 |
+
Compute the sum-square envelope of a window function at a given hop length.
|
39 |
+
|
40 |
+
This is used to estimate modulation effects induced by windowing
|
41 |
+
observations in short-time fourier transforms.
|
42 |
+
|
43 |
+
Parameters
|
44 |
+
----------
|
45 |
+
window : string, tuple, number, callable, or list-like
|
46 |
+
Window specification, as in `get_window`
|
47 |
+
|
48 |
+
n_frames : int > 0
|
49 |
+
The number of analysis frames
|
50 |
+
|
51 |
+
hop_length : int > 0
|
52 |
+
The number of samples to advance between frames
|
53 |
+
|
54 |
+
win_length : [optional]
|
55 |
+
The length of the window function. By default, this matches `n_fft`.
|
56 |
+
|
57 |
+
n_fft : int > 0
|
58 |
+
The length of each analysis frame.
|
59 |
+
|
60 |
+
dtype : np.dtype
|
61 |
+
The data type of the output
|
62 |
+
|
63 |
+
Returns
|
64 |
+
-------
|
65 |
+
wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
|
66 |
+
The sum-squared envelope of the window function
|
67 |
+
"""
|
68 |
+
if win_length is None:
|
69 |
+
win_length = n_fft
|
70 |
+
|
71 |
+
n = n_fft + hop_length * (n_frames - 1)
|
72 |
+
x = np.zeros(n, dtype=dtype)
|
73 |
+
|
74 |
+
# Compute the squared window at the desired length
|
75 |
+
win_sq = get_window(window, win_length, fftbins=True)
|
76 |
+
win_sq = normalize(win_sq, norm=norm) ** 2
|
77 |
+
win_sq = pad_center(win_sq, n_fft)
|
78 |
+
|
79 |
+
# Fill the envelope
|
80 |
+
for i in range(n_frames):
|
81 |
+
sample = i * hop_length
|
82 |
+
x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
|
83 |
+
return x
|
84 |
+
|
85 |
+
|
86 |
+
class STFT(torch.nn.Module):
|
87 |
+
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
|
88 |
+
|
89 |
+
def __init__(self, filter_length, hop_length, win_length, window="hann"):
|
90 |
+
super(STFT, self).__init__()
|
91 |
+
self.filter_length = filter_length
|
92 |
+
self.hop_length = hop_length
|
93 |
+
self.win_length = win_length
|
94 |
+
self.window = window
|
95 |
+
self.forward_transform = None
|
96 |
+
scale = self.filter_length / self.hop_length
|
97 |
+
fourier_basis = np.fft.fft(np.eye(self.filter_length))
|
98 |
+
|
99 |
+
cutoff = int((self.filter_length / 2 + 1))
|
100 |
+
fourier_basis = np.vstack(
|
101 |
+
[np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
|
102 |
+
)
|
103 |
+
|
104 |
+
forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
|
105 |
+
inverse_basis = torch.FloatTensor(
|
106 |
+
np.linalg.pinv(scale * fourier_basis).T[:, None, :]
|
107 |
+
)
|
108 |
+
|
109 |
+
if window is not None:
|
110 |
+
assert filter_length >= win_length
|
111 |
+
# get window and zero center pad it to filter_length
|
112 |
+
fft_window = get_window(window, win_length, fftbins=True)
|
113 |
+
fft_window = pad_center(fft_window, size=filter_length)
|
114 |
+
fft_window = torch.from_numpy(fft_window).float()
|
115 |
+
|
116 |
+
# window the bases
|
117 |
+
forward_basis *= fft_window
|
118 |
+
inverse_basis *= fft_window
|
119 |
+
|
120 |
+
self.register_buffer("forward_basis", forward_basis.float())
|
121 |
+
self.register_buffer("inverse_basis", inverse_basis.float())
|
122 |
+
|
123 |
+
def transform(self, input_data):
|
124 |
+
device = self.forward_basis.device
|
125 |
+
input_data = input_data.to(device)
|
126 |
+
|
127 |
+
num_batches = input_data.size(0)
|
128 |
+
num_samples = input_data.size(1)
|
129 |
+
|
130 |
+
self.num_samples = num_samples
|
131 |
+
|
132 |
+
# similar to librosa, reflect-pad the input
|
133 |
+
input_data = input_data.view(num_batches, 1, num_samples)
|
134 |
+
input_data = F.pad(
|
135 |
+
input_data.unsqueeze(1),
|
136 |
+
(int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
|
137 |
+
mode="reflect",
|
138 |
+
)
|
139 |
+
input_data = input_data.squeeze(1)
|
140 |
+
|
141 |
+
forward_transform = F.conv1d(
|
142 |
+
input_data,
|
143 |
+
torch.autograd.Variable(self.forward_basis, requires_grad=False),
|
144 |
+
stride=self.hop_length,
|
145 |
+
padding=0,
|
146 |
+
)
|
147 |
+
|
148 |
+
cutoff = int((self.filter_length / 2) + 1)
|
149 |
+
real_part = forward_transform[:, :cutoff, :]
|
150 |
+
imag_part = forward_transform[:, cutoff:, :]
|
151 |
+
|
152 |
+
magnitude = torch.sqrt(real_part**2 + imag_part**2)
|
153 |
+
phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
|
154 |
+
|
155 |
+
return magnitude, phase
|
156 |
+
|
157 |
+
def inverse(self, magnitude, phase):
|
158 |
+
device = self.forward_basis.device
|
159 |
+
magnitude, phase = magnitude.to(device), phase.to(device)
|
160 |
+
|
161 |
+
recombine_magnitude_phase = torch.cat(
|
162 |
+
[magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
|
163 |
+
)
|
164 |
+
|
165 |
+
inverse_transform = F.conv_transpose1d(
|
166 |
+
recombine_magnitude_phase,
|
167 |
+
torch.autograd.Variable(self.inverse_basis, requires_grad=False),
|
168 |
+
stride=self.hop_length,
|
169 |
+
padding=0,
|
170 |
+
)
|
171 |
+
|
172 |
+
if self.window is not None:
|
173 |
+
window_sum = window_sumsquare(
|
174 |
+
self.window,
|
175 |
+
magnitude.size(-1),
|
176 |
+
hop_length=self.hop_length,
|
177 |
+
win_length=self.win_length,
|
178 |
+
n_fft=self.filter_length,
|
179 |
+
dtype=np.float32,
|
180 |
+
)
|
181 |
+
# remove modulation effects
|
182 |
+
approx_nonzero_indices = torch.from_numpy(
|
183 |
+
np.where(window_sum > tiny(window_sum))[0]
|
184 |
+
)
|
185 |
+
window_sum = torch.autograd.Variable(
|
186 |
+
torch.from_numpy(window_sum), requires_grad=False
|
187 |
+
)
|
188 |
+
window_sum = window_sum
|
189 |
+
inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
|
190 |
+
approx_nonzero_indices
|
191 |
+
]
|
192 |
+
|
193 |
+
# scale by hop ratio
|
194 |
+
inverse_transform *= float(self.filter_length) / self.hop_length
|
195 |
+
|
196 |
+
inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
|
197 |
+
inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
|
198 |
+
|
199 |
+
return inverse_transform
|
200 |
+
|
201 |
+
def forward(self, input_data):
|
202 |
+
self.magnitude, self.phase = self.transform(input_data)
|
203 |
+
reconstruction = self.inverse(self.magnitude, self.phase)
|
204 |
+
return reconstruction
|
205 |
+
|
206 |
+
|
207 |
+
class TacotronSTFT(torch.nn.Module):
|
208 |
+
def __init__(
|
209 |
+
self,
|
210 |
+
filter_length,
|
211 |
+
hop_length,
|
212 |
+
win_length,
|
213 |
+
n_mel_channels,
|
214 |
+
sampling_rate,
|
215 |
+
mel_fmin,
|
216 |
+
mel_fmax,
|
217 |
+
):
|
218 |
+
super(TacotronSTFT, self).__init__()
|
219 |
+
self.n_mel_channels = n_mel_channels
|
220 |
+
self.sampling_rate = sampling_rate
|
221 |
+
self.stft_fn = STFT(filter_length, hop_length, win_length)
|
222 |
+
mel_basis = librosa_mel_fn(
|
223 |
+
sr=sampling_rate, n_fft=filter_length, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax
|
224 |
+
)
|
225 |
+
mel_basis = torch.from_numpy(mel_basis).float()
|
226 |
+
self.register_buffer("mel_basis", mel_basis)
|
227 |
+
|
228 |
+
def spectral_normalize(self, magnitudes, normalize_fun):
|
229 |
+
output = dynamic_range_compression(magnitudes, normalize_fun)
|
230 |
+
return output
|
231 |
+
|
232 |
+
def spectral_de_normalize(self, magnitudes):
|
233 |
+
output = dynamic_range_decompression(magnitudes)
|
234 |
+
return output
|
235 |
+
|
236 |
+
def mel_spectrogram(self, y, normalize_fun=torch.log):
|
237 |
+
"""Computes mel-spectrograms from a batch of waves
|
238 |
+
PARAMS
|
239 |
+
------
|
240 |
+
y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
|
241 |
+
|
242 |
+
RETURNS
|
243 |
+
-------
|
244 |
+
mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
|
245 |
+
"""
|
246 |
+
assert torch.min(y.data) >= -1, torch.min(y.data)
|
247 |
+
assert torch.max(y.data) <= 1, torch.max(y.data)
|
248 |
+
|
249 |
+
magnitudes, phases = self.stft_fn.transform(y)
|
250 |
+
magnitudes = magnitudes.data
|
251 |
+
mel_output = torch.matmul(self.mel_basis, magnitudes)
|
252 |
+
mel_output = self.spectral_normalize(mel_output, normalize_fun)
|
253 |
+
energy = torch.norm(magnitudes, dim=1)
|
254 |
+
|
255 |
+
log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun)
|
256 |
+
|
257 |
+
return mel_output, log_magnitudes, energy
|
audioldm/utils.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import importlib
|
3 |
+
|
4 |
+
|
5 |
+
CACHE_DIR = os.getenv(
|
6 |
+
"AUDIOLDM_CACHE_DIR",
|
7 |
+
os.path.join(os.path.expanduser("~"), ".cache/audioldm"))
|
8 |
+
|
9 |
+
|
10 |
+
def default_audioldm_config(model_name="audioldm-s-full"):
|
11 |
+
basic_config = {
|
12 |
+
"wave_file_save_path": "./output",
|
13 |
+
"id": {
|
14 |
+
"version": "v1",
|
15 |
+
"name": "default",
|
16 |
+
"root": "/mnt/fast/nobackup/users/hl01486/projects/general_audio_generation/AudioLDM-python/config/default/latent_diffusion.yaml",
|
17 |
+
},
|
18 |
+
"preprocessing": {
|
19 |
+
"audio": {"sampling_rate": 16000, "max_wav_value": 32768},
|
20 |
+
"stft": {"filter_length": 1024, "hop_length": 160, "win_length": 1024},
|
21 |
+
"mel": {
|
22 |
+
"n_mel_channels": 64,
|
23 |
+
"mel_fmin": 0,
|
24 |
+
"mel_fmax": 8000,
|
25 |
+
"freqm": 0,
|
26 |
+
"timem": 0,
|
27 |
+
"blur": False,
|
28 |
+
"mean": -4.63,
|
29 |
+
"std": 2.74,
|
30 |
+
"target_length": 1024,
|
31 |
+
},
|
32 |
+
},
|
33 |
+
"model": {
|
34 |
+
"device": "cuda",
|
35 |
+
"target": "audioldm.pipline.LatentDiffusion",
|
36 |
+
"params": {
|
37 |
+
"base_learning_rate": 5e-06,
|
38 |
+
"linear_start": 0.0015,
|
39 |
+
"linear_end": 0.0195,
|
40 |
+
"num_timesteps_cond": 1,
|
41 |
+
"log_every_t": 200,
|
42 |
+
"timesteps": 1000,
|
43 |
+
"first_stage_key": "fbank",
|
44 |
+
"cond_stage_key": "waveform",
|
45 |
+
"latent_t_size": 256,
|
46 |
+
"latent_f_size": 16,
|
47 |
+
"channels": 8,
|
48 |
+
"cond_stage_trainable": True,
|
49 |
+
"conditioning_key": "film",
|
50 |
+
"monitor": "val/loss_simple_ema",
|
51 |
+
"scale_by_std": True,
|
52 |
+
"unet_config": {
|
53 |
+
"target": "audioldm.latent_diffusion.openaimodel.UNetModel",
|
54 |
+
"params": {
|
55 |
+
"image_size": 64,
|
56 |
+
"extra_film_condition_dim": 512,
|
57 |
+
"extra_film_use_concat": True,
|
58 |
+
"in_channels": 8,
|
59 |
+
"out_channels": 8,
|
60 |
+
"model_channels": 128,
|
61 |
+
"attention_resolutions": [8, 4, 2],
|
62 |
+
"num_res_blocks": 2,
|
63 |
+
"channel_mult": [1, 2, 3, 5],
|
64 |
+
"num_head_channels": 32,
|
65 |
+
"use_spatial_transformer": True,
|
66 |
+
},
|
67 |
+
},
|
68 |
+
"first_stage_config": {
|
69 |
+
"base_learning_rate": 4.5e-05,
|
70 |
+
"target": "audioldm.variational_autoencoder.autoencoder.AutoencoderKL",
|
71 |
+
"params": {
|
72 |
+
"monitor": "val/rec_loss",
|
73 |
+
"image_key": "fbank",
|
74 |
+
"subband": 1,
|
75 |
+
"embed_dim": 8,
|
76 |
+
"time_shuffle": 1,
|
77 |
+
"ddconfig": {
|
78 |
+
"double_z": True,
|
79 |
+
"z_channels": 8,
|
80 |
+
"resolution": 256,
|
81 |
+
"downsample_time": False,
|
82 |
+
"in_channels": 1,
|
83 |
+
"out_ch": 1,
|
84 |
+
"ch": 128,
|
85 |
+
"ch_mult": [1, 2, 4],
|
86 |
+
"num_res_blocks": 2,
|
87 |
+
"attn_resolutions": [],
|
88 |
+
"dropout": 0.0,
|
89 |
+
},
|
90 |
+
},
|
91 |
+
},
|
92 |
+
"cond_stage_config": {
|
93 |
+
"target": "audioldm.clap.encoders.CLAPAudioEmbeddingClassifierFreev2",
|
94 |
+
"params": {
|
95 |
+
"key": "waveform",
|
96 |
+
"sampling_rate": 16000,
|
97 |
+
"embed_mode": "audio",
|
98 |
+
"unconditional_prob": 0.1,
|
99 |
+
},
|
100 |
+
},
|
101 |
+
},
|
102 |
+
},
|
103 |
+
}
|
104 |
+
|
105 |
+
if("-l-" in model_name):
|
106 |
+
basic_config["model"]["params"]["unet_config"]["params"]["model_channels"] = 256
|
107 |
+
basic_config["model"]["params"]["unet_config"]["params"]["num_head_channels"] = 64
|
108 |
+
elif("-m-" in model_name):
|
109 |
+
basic_config["model"]["params"]["unet_config"]["params"]["model_channels"] = 192
|
110 |
+
basic_config["model"]["params"]["cond_stage_config"]["params"]["amodel"] = "HTSAT-base" # This model use a larger HTAST
|
111 |
+
|
112 |
+
return basic_config
|
113 |
+
|
114 |
+
|
115 |
+
def get_metadata():
|
116 |
+
return {
|
117 |
+
"audioldm-s-full": {
|
118 |
+
"path": os.path.join(
|
119 |
+
CACHE_DIR,
|
120 |
+
"audioldm-s-full.ckpt",
|
121 |
+
),
|
122 |
+
"url": "https://zenodo.org/record/7600541/files/audioldm-s-full?download=1",
|
123 |
+
},
|
124 |
+
"audioldm-l-full": {
|
125 |
+
"path": os.path.join(
|
126 |
+
CACHE_DIR,
|
127 |
+
"audioldm-l-full.ckpt",
|
128 |
+
),
|
129 |
+
"url": "https://zenodo.org/record/7698295/files/audioldm-full-l.ckpt?download=1",
|
130 |
+
},
|
131 |
+
"audioldm-s-full-v2": {
|
132 |
+
"path": os.path.join(
|
133 |
+
CACHE_DIR,
|
134 |
+
"audioldm-s-full-v2.ckpt",
|
135 |
+
),
|
136 |
+
"url": "https://zenodo.org/record/7698295/files/audioldm-full-s-v2.ckpt?download=1",
|
137 |
+
},
|
138 |
+
"audioldm-m-text-ft": {
|
139 |
+
"path": os.path.join(
|
140 |
+
CACHE_DIR,
|
141 |
+
"audioldm-m-text-ft.ckpt",
|
142 |
+
),
|
143 |
+
"url": "https://zenodo.org/record/7813012/files/audioldm-m-text-ft.ckpt?download=1",
|
144 |
+
},
|
145 |
+
"audioldm-s-text-ft": {
|
146 |
+
"path": os.path.join(
|
147 |
+
CACHE_DIR,
|
148 |
+
"audioldm-s-text-ft.ckpt",
|
149 |
+
),
|
150 |
+
"url": "https://zenodo.org/record/7813012/files/audioldm-s-text-ft.ckpt?download=1",
|
151 |
+
},
|
152 |
+
"audioldm-m-full": {
|
153 |
+
"path": os.path.join(
|
154 |
+
CACHE_DIR,
|
155 |
+
"audioldm-m-full.ckpt",
|
156 |
+
),
|
157 |
+
"url": "https://zenodo.org/record/7813012/files/audioldm-m-full.ckpt?download=1",
|
158 |
+
},
|
159 |
+
}
|
160 |
+
|
161 |
+
|
162 |
+
def get_obj_from_str(string, reload=False):
|
163 |
+
module, cls = string.rsplit(".", 1)
|
164 |
+
if reload:
|
165 |
+
module_imp = importlib.import_module(module)
|
166 |
+
importlib.reload(module_imp)
|
167 |
+
return getattr(importlib.import_module(module, package=None), cls)
|
168 |
+
|
169 |
+
|
170 |
+
def instantiate_from_config(config):
|
171 |
+
if not "target" in config:
|
172 |
+
if config == "__is_first_stage__":
|
173 |
+
return None
|
174 |
+
elif config == "__is_unconditional__":
|
175 |
+
return None
|
176 |
+
raise KeyError("Expected key `target` to instantiate.")
|
177 |
+
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
audioldm/variational_autoencoder/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .autoencoder import AutoencoderKL
|
audioldm/variational_autoencoder/autoencoder.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
from audioldm.variational_autoencoder.modules import Encoder, Decoder
|
5 |
+
from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution
|
6 |
+
from audioldm.hifigan.utilities import get_vocoder, vocoder_infer
|
7 |
+
|
8 |
+
|
9 |
+
class AutoencoderKL(nn.Module):
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
ddconfig=None,
|
13 |
+
lossconfig=None,
|
14 |
+
image_key="fbank",
|
15 |
+
embed_dim=None,
|
16 |
+
time_shuffle=1,
|
17 |
+
subband=1,
|
18 |
+
ckpt_path=None,
|
19 |
+
reload_from_ckpt=None,
|
20 |
+
ignore_keys=[],
|
21 |
+
colorize_nlabels=None,
|
22 |
+
monitor=None,
|
23 |
+
base_learning_rate=1e-5,
|
24 |
+
scale_factor=1
|
25 |
+
):
|
26 |
+
super().__init__()
|
27 |
+
|
28 |
+
self.encoder = Encoder(**ddconfig)
|
29 |
+
self.decoder = Decoder(**ddconfig)
|
30 |
+
self.ema_decoder = None
|
31 |
+
|
32 |
+
self.subband = int(subband)
|
33 |
+
if self.subband > 1:
|
34 |
+
print("Use subband decomposition %s" % self.subband)
|
35 |
+
|
36 |
+
self.quant_conv = nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
|
37 |
+
self.post_quant_conv = nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
38 |
+
self.ema_post_quant_conv = None
|
39 |
+
|
40 |
+
self.vocoder = get_vocoder(None, "cpu")
|
41 |
+
self.embed_dim = embed_dim
|
42 |
+
|
43 |
+
if monitor is not None:
|
44 |
+
self.monitor = monitor
|
45 |
+
|
46 |
+
self.time_shuffle = time_shuffle
|
47 |
+
self.reload_from_ckpt = reload_from_ckpt
|
48 |
+
self.reloaded = False
|
49 |
+
self.mean, self.std = None, None
|
50 |
+
|
51 |
+
self.scale_factor = scale_factor
|
52 |
+
|
53 |
+
@property
|
54 |
+
def device(self):
|
55 |
+
return next(self.parameters()).device
|
56 |
+
|
57 |
+
def freq_split_subband(self, fbank):
|
58 |
+
if self.subband == 1 or self.image_key != "stft":
|
59 |
+
return fbank
|
60 |
+
|
61 |
+
bs, ch, tstep, fbins = fbank.size()
|
62 |
+
|
63 |
+
assert fbank.size(-1) % self.subband == 0
|
64 |
+
assert ch == 1
|
65 |
+
|
66 |
+
return (
|
67 |
+
fbank.squeeze(1)
|
68 |
+
.reshape(bs, tstep, self.subband, fbins // self.subband)
|
69 |
+
.permute(0, 2, 1, 3)
|
70 |
+
)
|
71 |
+
|
72 |
+
def freq_merge_subband(self, subband_fbank):
|
73 |
+
if self.subband == 1 or self.image_key != "stft":
|
74 |
+
return subband_fbank
|
75 |
+
assert subband_fbank.size(1) == self.subband # Channel dimension
|
76 |
+
bs, sub_ch, tstep, fbins = subband_fbank.size()
|
77 |
+
return subband_fbank.permute(0, 2, 1, 3).reshape(bs, tstep, -1).unsqueeze(1)
|
78 |
+
|
79 |
+
def encode(self, x):
|
80 |
+
x = self.freq_split_subband(x)
|
81 |
+
h = self.encoder(x)
|
82 |
+
moments = self.quant_conv(h)
|
83 |
+
posterior = DiagonalGaussianDistribution(moments)
|
84 |
+
return posterior
|
85 |
+
|
86 |
+
@torch.no_grad()
|
87 |
+
def encode_first_stage(self, x):
|
88 |
+
return self.encode(x)
|
89 |
+
|
90 |
+
def decode(self, z, use_ema=False):
|
91 |
+
if use_ema and (not hasattr(self, 'ema_decoder') or self.ema_decoder is None):
|
92 |
+
print("VAE does not have EMA modules, but specified use_ema. "
|
93 |
+
"Using the none-EMA modules instead.")
|
94 |
+
if use_ema and hasattr(self, 'ema_decoder') and self.ema_decoder is not None:
|
95 |
+
z = self.ema_post_quant_conv(z)
|
96 |
+
dec = self.ema_decoder(z)
|
97 |
+
else:
|
98 |
+
z = self.post_quant_conv(z)
|
99 |
+
dec = self.decoder(z)
|
100 |
+
return self.freq_merge_subband(dec)
|
101 |
+
|
102 |
+
def decode_first_stage(self, z, allow_grad=False, use_ema=False):
|
103 |
+
with torch.set_grad_enabled(allow_grad):
|
104 |
+
z = z / self.scale_factor
|
105 |
+
return self.decode(z, use_ema)
|
106 |
+
|
107 |
+
def decode_to_waveform(self, dec, allow_grad=False):
|
108 |
+
dec = dec.squeeze(1).permute(0, 2, 1)
|
109 |
+
wav_reconstruction = vocoder_infer(dec, self.vocoder, allow_grad=allow_grad)
|
110 |
+
return wav_reconstruction
|
111 |
+
|
112 |
+
def forward(self, input, sample_posterior=True):
|
113 |
+
posterior = self.encode(input)
|
114 |
+
z = posterior.sample() if sample_posterior else posterior.mode()
|
115 |
+
|
116 |
+
if self.flag_first_run:
|
117 |
+
print("Latent size: ", z.size())
|
118 |
+
self.flag_first_run = False
|
119 |
+
|
120 |
+
return self.decode(z), posterior
|
121 |
+
|
122 |
+
def get_first_stage_encoding(self, encoder_posterior):
|
123 |
+
if isinstance(encoder_posterior, DiagonalGaussianDistribution):
|
124 |
+
z = encoder_posterior.sample()
|
125 |
+
elif isinstance(encoder_posterior, torch.Tensor):
|
126 |
+
z = encoder_posterior
|
127 |
+
else:
|
128 |
+
raise NotImplementedError(
|
129 |
+
f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
|
130 |
+
)
|
131 |
+
return self.scale_factor * z
|
audioldm/variational_autoencoder/distributions.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
|
5 |
+
class AbstractDistribution:
|
6 |
+
def sample(self):
|
7 |
+
raise NotImplementedError()
|
8 |
+
|
9 |
+
def mode(self):
|
10 |
+
raise NotImplementedError()
|
11 |
+
|
12 |
+
|
13 |
+
class DiracDistribution(AbstractDistribution):
|
14 |
+
def __init__(self, value):
|
15 |
+
self.value = value
|
16 |
+
|
17 |
+
def sample(self):
|
18 |
+
return self.value
|
19 |
+
|
20 |
+
def mode(self):
|
21 |
+
return self.value
|
22 |
+
|
23 |
+
|
24 |
+
class DiagonalGaussianDistribution(object):
|
25 |
+
def __init__(self, parameters, deterministic=False):
|
26 |
+
self.parameters = parameters
|
27 |
+
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
28 |
+
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
29 |
+
self.deterministic = deterministic
|
30 |
+
self.std = torch.exp(0.5 * self.logvar)
|
31 |
+
self.var = torch.exp(self.logvar)
|
32 |
+
if self.deterministic:
|
33 |
+
self.var = self.std = torch.zeros_like(self.mean).to(
|
34 |
+
device=self.parameters.device
|
35 |
+
)
|
36 |
+
|
37 |
+
def sample(self):
|
38 |
+
x = self.mean + self.std * torch.randn(self.mean.shape).to(
|
39 |
+
device=self.parameters.device
|
40 |
+
)
|
41 |
+
return x
|
42 |
+
|
43 |
+
def kl(self, other=None):
|
44 |
+
if self.deterministic:
|
45 |
+
return torch.Tensor([0.0])
|
46 |
+
else:
|
47 |
+
if other is None:
|
48 |
+
return 0.5 * torch.mean(
|
49 |
+
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
|
50 |
+
dim=[1, 2, 3],
|
51 |
+
)
|
52 |
+
else:
|
53 |
+
return 0.5 * torch.mean(
|
54 |
+
torch.pow(self.mean - other.mean, 2) / other.var
|
55 |
+
+ self.var / other.var
|
56 |
+
- 1.0
|
57 |
+
- self.logvar
|
58 |
+
+ other.logvar,
|
59 |
+
dim=[1, 2, 3],
|
60 |
+
)
|
61 |
+
|
62 |
+
def nll(self, sample, dims=[1, 2, 3]):
|
63 |
+
if self.deterministic:
|
64 |
+
return torch.Tensor([0.0])
|
65 |
+
logtwopi = np.log(2.0 * np.pi)
|
66 |
+
return 0.5 * torch.sum(
|
67 |
+
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
68 |
+
dim=dims,
|
69 |
+
)
|
70 |
+
|
71 |
+
def mode(self):
|
72 |
+
return self.mean
|
73 |
+
|
74 |
+
|
75 |
+
def normal_kl(mean1, logvar1, mean2, logvar2):
|
76 |
+
"""
|
77 |
+
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
|
78 |
+
Compute the KL divergence between two gaussians.
|
79 |
+
Shapes are automatically broadcasted, so batches can be compared to
|
80 |
+
scalars, among other use cases.
|
81 |
+
"""
|
82 |
+
tensor = None
|
83 |
+
for obj in (mean1, logvar1, mean2, logvar2):
|
84 |
+
if isinstance(obj, torch.Tensor):
|
85 |
+
tensor = obj
|
86 |
+
break
|
87 |
+
assert tensor is not None, "at least one argument must be a Tensor"
|
88 |
+
|
89 |
+
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
90 |
+
# Tensors, but it does not work for torch.exp().
|
91 |
+
logvar1, logvar2 = [
|
92 |
+
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
|
93 |
+
for x in (logvar1, logvar2)
|
94 |
+
]
|
95 |
+
|
96 |
+
return 0.5 * (
|
97 |
+
-1.0
|
98 |
+
+ logvar2
|
99 |
+
- logvar1
|
100 |
+
+ torch.exp(logvar1 - logvar2)
|
101 |
+
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
|
102 |
+
)
|
audioldm/variational_autoencoder/modules.py
ADDED
@@ -0,0 +1,1067 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# pytorch_diffusion + derived encoder decoder
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import numpy as np
|
6 |
+
from einops import rearrange
|
7 |
+
|
8 |
+
from audioldm.utils import instantiate_from_config
|
9 |
+
from audioldm.latent_diffusion.attention import LinearAttention
|
10 |
+
|
11 |
+
|
12 |
+
def get_timestep_embedding(timesteps, embedding_dim):
|
13 |
+
"""
|
14 |
+
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
15 |
+
From Fairseq.
|
16 |
+
Build sinusoidal embeddings.
|
17 |
+
This matches the implementation in tensor2tensor, but differs slightly
|
18 |
+
from the description in Section 3.5 of "Attention Is All You Need".
|
19 |
+
"""
|
20 |
+
assert len(timesteps.shape) == 1
|
21 |
+
|
22 |
+
half_dim = embedding_dim // 2
|
23 |
+
emb = math.log(10000) / (half_dim - 1)
|
24 |
+
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
25 |
+
emb = emb.to(device=timesteps.device)
|
26 |
+
emb = timesteps.float()[:, None] * emb[None, :]
|
27 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
28 |
+
if embedding_dim % 2 == 1: # zero pad
|
29 |
+
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
30 |
+
return emb
|
31 |
+
|
32 |
+
|
33 |
+
def nonlinearity(x):
|
34 |
+
# swish
|
35 |
+
return x * torch.sigmoid(x)
|
36 |
+
|
37 |
+
|
38 |
+
def Normalize(in_channels, num_groups=32):
|
39 |
+
return torch.nn.GroupNorm(
|
40 |
+
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
|
41 |
+
)
|
42 |
+
|
43 |
+
|
44 |
+
class Upsample(nn.Module):
|
45 |
+
def __init__(self, in_channels, with_conv):
|
46 |
+
super().__init__()
|
47 |
+
self.with_conv = with_conv
|
48 |
+
if self.with_conv:
|
49 |
+
self.conv = torch.nn.Conv2d(
|
50 |
+
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
51 |
+
)
|
52 |
+
|
53 |
+
def forward(self, x):
|
54 |
+
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
55 |
+
if self.with_conv:
|
56 |
+
x = self.conv(x)
|
57 |
+
return x
|
58 |
+
|
59 |
+
|
60 |
+
class UpsampleTimeStride4(nn.Module):
|
61 |
+
def __init__(self, in_channels, with_conv):
|
62 |
+
super().__init__()
|
63 |
+
self.with_conv = with_conv
|
64 |
+
if self.with_conv:
|
65 |
+
self.conv = torch.nn.Conv2d(
|
66 |
+
in_channels, in_channels, kernel_size=5, stride=1, padding=2
|
67 |
+
)
|
68 |
+
|
69 |
+
def forward(self, x):
|
70 |
+
x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest")
|
71 |
+
if self.with_conv:
|
72 |
+
x = self.conv(x)
|
73 |
+
return x
|
74 |
+
|
75 |
+
|
76 |
+
class Downsample(nn.Module):
|
77 |
+
def __init__(self, in_channels, with_conv):
|
78 |
+
super().__init__()
|
79 |
+
self.with_conv = with_conv
|
80 |
+
if self.with_conv:
|
81 |
+
# Do time downsampling here
|
82 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
83 |
+
self.conv = torch.nn.Conv2d(
|
84 |
+
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
85 |
+
)
|
86 |
+
|
87 |
+
def forward(self, x):
|
88 |
+
if self.with_conv:
|
89 |
+
pad = (0, 1, 0, 1)
|
90 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
91 |
+
x = self.conv(x)
|
92 |
+
else:
|
93 |
+
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
94 |
+
return x
|
95 |
+
|
96 |
+
|
97 |
+
class DownsampleTimeStride4(nn.Module):
|
98 |
+
def __init__(self, in_channels, with_conv):
|
99 |
+
super().__init__()
|
100 |
+
self.with_conv = with_conv
|
101 |
+
if self.with_conv:
|
102 |
+
# Do time downsampling here
|
103 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
104 |
+
self.conv = torch.nn.Conv2d(
|
105 |
+
in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1
|
106 |
+
)
|
107 |
+
|
108 |
+
def forward(self, x):
|
109 |
+
if self.with_conv:
|
110 |
+
pad = (0, 1, 0, 1)
|
111 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
112 |
+
x = self.conv(x)
|
113 |
+
else:
|
114 |
+
x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2))
|
115 |
+
return x
|
116 |
+
|
117 |
+
|
118 |
+
class ResnetBlock(nn.Module):
|
119 |
+
def __init__(
|
120 |
+
self,
|
121 |
+
*,
|
122 |
+
in_channels,
|
123 |
+
out_channels=None,
|
124 |
+
conv_shortcut=False,
|
125 |
+
dropout,
|
126 |
+
temb_channels=512,
|
127 |
+
):
|
128 |
+
super().__init__()
|
129 |
+
self.in_channels = in_channels
|
130 |
+
out_channels = in_channels if out_channels is None else out_channels
|
131 |
+
self.out_channels = out_channels
|
132 |
+
self.use_conv_shortcut = conv_shortcut
|
133 |
+
|
134 |
+
self.norm1 = Normalize(in_channels)
|
135 |
+
self.conv1 = torch.nn.Conv2d(
|
136 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
137 |
+
)
|
138 |
+
if temb_channels > 0:
|
139 |
+
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
140 |
+
self.norm2 = Normalize(out_channels)
|
141 |
+
self.dropout = torch.nn.Dropout(dropout)
|
142 |
+
self.conv2 = torch.nn.Conv2d(
|
143 |
+
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
144 |
+
)
|
145 |
+
if self.in_channels != self.out_channels:
|
146 |
+
if self.use_conv_shortcut:
|
147 |
+
self.conv_shortcut = torch.nn.Conv2d(
|
148 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
149 |
+
)
|
150 |
+
else:
|
151 |
+
self.nin_shortcut = torch.nn.Conv2d(
|
152 |
+
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
153 |
+
)
|
154 |
+
|
155 |
+
def forward(self, x, temb):
|
156 |
+
h = x
|
157 |
+
h = self.norm1(h)
|
158 |
+
h = nonlinearity(h)
|
159 |
+
h = self.conv1(h)
|
160 |
+
|
161 |
+
if temb is not None:
|
162 |
+
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
163 |
+
|
164 |
+
h = self.norm2(h)
|
165 |
+
h = nonlinearity(h)
|
166 |
+
h = self.dropout(h)
|
167 |
+
h = self.conv2(h)
|
168 |
+
|
169 |
+
if self.in_channels != self.out_channels:
|
170 |
+
if self.use_conv_shortcut:
|
171 |
+
x = self.conv_shortcut(x)
|
172 |
+
else:
|
173 |
+
x = self.nin_shortcut(x)
|
174 |
+
|
175 |
+
return x + h
|
176 |
+
|
177 |
+
|
178 |
+
class LinAttnBlock(LinearAttention):
|
179 |
+
"""to match AttnBlock usage"""
|
180 |
+
|
181 |
+
def __init__(self, in_channels):
|
182 |
+
super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
|
183 |
+
|
184 |
+
|
185 |
+
class AttnBlock(nn.Module):
|
186 |
+
def __init__(self, in_channels):
|
187 |
+
super().__init__()
|
188 |
+
self.in_channels = in_channels
|
189 |
+
|
190 |
+
self.norm = Normalize(in_channels)
|
191 |
+
self.q = torch.nn.Conv2d(
|
192 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
193 |
+
)
|
194 |
+
self.k = torch.nn.Conv2d(
|
195 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
196 |
+
)
|
197 |
+
self.v = torch.nn.Conv2d(
|
198 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
199 |
+
)
|
200 |
+
self.proj_out = torch.nn.Conv2d(
|
201 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
202 |
+
)
|
203 |
+
|
204 |
+
def forward(self, x):
|
205 |
+
h_ = x
|
206 |
+
h_ = self.norm(h_)
|
207 |
+
q = self.q(h_)
|
208 |
+
k = self.k(h_)
|
209 |
+
v = self.v(h_)
|
210 |
+
|
211 |
+
# compute attention
|
212 |
+
b, c, h, w = q.shape
|
213 |
+
q = q.reshape(b, c, h * w).contiguous()
|
214 |
+
q = q.permute(0, 2, 1).contiguous() # b,hw,c
|
215 |
+
k = k.reshape(b, c, h * w).contiguous() # b,c,hw
|
216 |
+
w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
217 |
+
w_ = w_ * (int(c) ** (-0.5))
|
218 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
219 |
+
|
220 |
+
# attend to values
|
221 |
+
v = v.reshape(b, c, h * w).contiguous()
|
222 |
+
w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q)
|
223 |
+
h_ = torch.bmm(
|
224 |
+
v, w_
|
225 |
+
).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
226 |
+
h_ = h_.reshape(b, c, h, w).contiguous()
|
227 |
+
|
228 |
+
h_ = self.proj_out(h_)
|
229 |
+
|
230 |
+
return x + h_
|
231 |
+
|
232 |
+
|
233 |
+
def make_attn(in_channels, attn_type="vanilla"):
|
234 |
+
assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
|
235 |
+
# print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
236 |
+
if attn_type == "vanilla":
|
237 |
+
return AttnBlock(in_channels)
|
238 |
+
elif attn_type == "none":
|
239 |
+
return nn.Identity(in_channels)
|
240 |
+
else:
|
241 |
+
return LinAttnBlock(in_channels)
|
242 |
+
|
243 |
+
|
244 |
+
class Model(nn.Module):
|
245 |
+
def __init__(
|
246 |
+
self,
|
247 |
+
*,
|
248 |
+
ch,
|
249 |
+
out_ch,
|
250 |
+
ch_mult=(1, 2, 4, 8),
|
251 |
+
num_res_blocks,
|
252 |
+
attn_resolutions,
|
253 |
+
dropout=0.0,
|
254 |
+
resamp_with_conv=True,
|
255 |
+
in_channels,
|
256 |
+
resolution,
|
257 |
+
use_timestep=True,
|
258 |
+
use_linear_attn=False,
|
259 |
+
attn_type="vanilla",
|
260 |
+
):
|
261 |
+
super().__init__()
|
262 |
+
if use_linear_attn:
|
263 |
+
attn_type = "linear"
|
264 |
+
self.ch = ch
|
265 |
+
self.temb_ch = self.ch * 4
|
266 |
+
self.num_resolutions = len(ch_mult)
|
267 |
+
self.num_res_blocks = num_res_blocks
|
268 |
+
self.resolution = resolution
|
269 |
+
self.in_channels = in_channels
|
270 |
+
|
271 |
+
self.use_timestep = use_timestep
|
272 |
+
if self.use_timestep:
|
273 |
+
# timestep embedding
|
274 |
+
self.temb = nn.Module()
|
275 |
+
self.temb.dense = nn.ModuleList(
|
276 |
+
[
|
277 |
+
torch.nn.Linear(self.ch, self.temb_ch),
|
278 |
+
torch.nn.Linear(self.temb_ch, self.temb_ch),
|
279 |
+
]
|
280 |
+
)
|
281 |
+
|
282 |
+
# downsampling
|
283 |
+
self.conv_in = torch.nn.Conv2d(
|
284 |
+
in_channels, self.ch, kernel_size=3, stride=1, padding=1
|
285 |
+
)
|
286 |
+
|
287 |
+
curr_res = resolution
|
288 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
289 |
+
self.down = nn.ModuleList()
|
290 |
+
for i_level in range(self.num_resolutions):
|
291 |
+
block = nn.ModuleList()
|
292 |
+
attn = nn.ModuleList()
|
293 |
+
block_in = ch * in_ch_mult[i_level]
|
294 |
+
block_out = ch * ch_mult[i_level]
|
295 |
+
for i_block in range(self.num_res_blocks):
|
296 |
+
block.append(
|
297 |
+
ResnetBlock(
|
298 |
+
in_channels=block_in,
|
299 |
+
out_channels=block_out,
|
300 |
+
temb_channels=self.temb_ch,
|
301 |
+
dropout=dropout,
|
302 |
+
)
|
303 |
+
)
|
304 |
+
block_in = block_out
|
305 |
+
if curr_res in attn_resolutions:
|
306 |
+
attn.append(make_attn(block_in, attn_type=attn_type))
|
307 |
+
down = nn.Module()
|
308 |
+
down.block = block
|
309 |
+
down.attn = attn
|
310 |
+
if i_level != self.num_resolutions - 1:
|
311 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
312 |
+
curr_res = curr_res // 2
|
313 |
+
self.down.append(down)
|
314 |
+
|
315 |
+
# middle
|
316 |
+
self.mid = nn.Module()
|
317 |
+
self.mid.block_1 = ResnetBlock(
|
318 |
+
in_channels=block_in,
|
319 |
+
out_channels=block_in,
|
320 |
+
temb_channels=self.temb_ch,
|
321 |
+
dropout=dropout,
|
322 |
+
)
|
323 |
+
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
324 |
+
self.mid.block_2 = ResnetBlock(
|
325 |
+
in_channels=block_in,
|
326 |
+
out_channels=block_in,
|
327 |
+
temb_channels=self.temb_ch,
|
328 |
+
dropout=dropout,
|
329 |
+
)
|
330 |
+
|
331 |
+
# upsampling
|
332 |
+
self.up = nn.ModuleList()
|
333 |
+
for i_level in reversed(range(self.num_resolutions)):
|
334 |
+
block = nn.ModuleList()
|
335 |
+
attn = nn.ModuleList()
|
336 |
+
block_out = ch * ch_mult[i_level]
|
337 |
+
skip_in = ch * ch_mult[i_level]
|
338 |
+
for i_block in range(self.num_res_blocks + 1):
|
339 |
+
if i_block == self.num_res_blocks:
|
340 |
+
skip_in = ch * in_ch_mult[i_level]
|
341 |
+
block.append(
|
342 |
+
ResnetBlock(
|
343 |
+
in_channels=block_in + skip_in,
|
344 |
+
out_channels=block_out,
|
345 |
+
temb_channels=self.temb_ch,
|
346 |
+
dropout=dropout,
|
347 |
+
)
|
348 |
+
)
|
349 |
+
block_in = block_out
|
350 |
+
if curr_res in attn_resolutions:
|
351 |
+
attn.append(make_attn(block_in, attn_type=attn_type))
|
352 |
+
up = nn.Module()
|
353 |
+
up.block = block
|
354 |
+
up.attn = attn
|
355 |
+
if i_level != 0:
|
356 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
357 |
+
curr_res = curr_res * 2
|
358 |
+
self.up.insert(0, up) # prepend to get consistent order
|
359 |
+
|
360 |
+
# end
|
361 |
+
self.norm_out = Normalize(block_in)
|
362 |
+
self.conv_out = torch.nn.Conv2d(
|
363 |
+
block_in, out_ch, kernel_size=3, stride=1, padding=1
|
364 |
+
)
|
365 |
+
|
366 |
+
def forward(self, x, t=None, context=None):
|
367 |
+
# assert x.shape[2] == x.shape[3] == self.resolution
|
368 |
+
if context is not None:
|
369 |
+
# assume aligned context, cat along channel axis
|
370 |
+
x = torch.cat((x, context), dim=1)
|
371 |
+
if self.use_timestep:
|
372 |
+
# timestep embedding
|
373 |
+
assert t is not None
|
374 |
+
temb = get_timestep_embedding(t, self.ch)
|
375 |
+
temb = self.temb.dense[0](temb)
|
376 |
+
temb = nonlinearity(temb)
|
377 |
+
temb = self.temb.dense[1](temb)
|
378 |
+
else:
|
379 |
+
temb = None
|
380 |
+
|
381 |
+
# downsampling
|
382 |
+
hs = [self.conv_in(x)]
|
383 |
+
for i_level in range(self.num_resolutions):
|
384 |
+
for i_block in range(self.num_res_blocks):
|
385 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
386 |
+
if len(self.down[i_level].attn) > 0:
|
387 |
+
h = self.down[i_level].attn[i_block](h)
|
388 |
+
hs.append(h)
|
389 |
+
if i_level != self.num_resolutions - 1:
|
390 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
391 |
+
|
392 |
+
# middle
|
393 |
+
h = hs[-1]
|
394 |
+
h = self.mid.block_1(h, temb)
|
395 |
+
h = self.mid.attn_1(h)
|
396 |
+
h = self.mid.block_2(h, temb)
|
397 |
+
|
398 |
+
# upsampling
|
399 |
+
for i_level in reversed(range(self.num_resolutions)):
|
400 |
+
for i_block in range(self.num_res_blocks + 1):
|
401 |
+
h = self.up[i_level].block[i_block](
|
402 |
+
torch.cat([h, hs.pop()], dim=1), temb
|
403 |
+
)
|
404 |
+
if len(self.up[i_level].attn) > 0:
|
405 |
+
h = self.up[i_level].attn[i_block](h)
|
406 |
+
if i_level != 0:
|
407 |
+
h = self.up[i_level].upsample(h)
|
408 |
+
|
409 |
+
# end
|
410 |
+
h = self.norm_out(h)
|
411 |
+
h = nonlinearity(h)
|
412 |
+
h = self.conv_out(h)
|
413 |
+
return h
|
414 |
+
|
415 |
+
def get_last_layer(self):
|
416 |
+
return self.conv_out.weight
|
417 |
+
|
418 |
+
|
419 |
+
class Encoder(nn.Module):
|
420 |
+
def __init__(
|
421 |
+
self,
|
422 |
+
*,
|
423 |
+
ch,
|
424 |
+
out_ch,
|
425 |
+
ch_mult=(1, 2, 4, 8),
|
426 |
+
num_res_blocks,
|
427 |
+
attn_resolutions,
|
428 |
+
dropout=0.0,
|
429 |
+
resamp_with_conv=True,
|
430 |
+
in_channels,
|
431 |
+
resolution,
|
432 |
+
z_channels,
|
433 |
+
double_z=True,
|
434 |
+
use_linear_attn=False,
|
435 |
+
attn_type="vanilla",
|
436 |
+
downsample_time_stride4_levels=[],
|
437 |
+
**ignore_kwargs,
|
438 |
+
):
|
439 |
+
super().__init__()
|
440 |
+
if use_linear_attn:
|
441 |
+
attn_type = "linear"
|
442 |
+
self.ch = ch
|
443 |
+
self.temb_ch = 0
|
444 |
+
self.num_resolutions = len(ch_mult)
|
445 |
+
self.num_res_blocks = num_res_blocks
|
446 |
+
self.resolution = resolution
|
447 |
+
self.in_channels = in_channels
|
448 |
+
self.downsample_time_stride4_levels = downsample_time_stride4_levels
|
449 |
+
|
450 |
+
if len(self.downsample_time_stride4_levels) > 0:
|
451 |
+
assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
|
452 |
+
"The level to perform downsample 4 operation need to be smaller than "
|
453 |
+
"the total resolution number %s" % str(self.num_resolutions)
|
454 |
+
)
|
455 |
+
|
456 |
+
# downsampling
|
457 |
+
self.conv_in = torch.nn.Conv2d(
|
458 |
+
in_channels, self.ch, kernel_size=3, stride=1, padding=1
|
459 |
+
)
|
460 |
+
|
461 |
+
curr_res = resolution
|
462 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
463 |
+
self.in_ch_mult = in_ch_mult
|
464 |
+
self.down = nn.ModuleList()
|
465 |
+
for i_level in range(self.num_resolutions):
|
466 |
+
block = nn.ModuleList()
|
467 |
+
attn = nn.ModuleList()
|
468 |
+
block_in = ch * in_ch_mult[i_level]
|
469 |
+
block_out = ch * ch_mult[i_level]
|
470 |
+
for i_block in range(self.num_res_blocks):
|
471 |
+
block.append(
|
472 |
+
ResnetBlock(
|
473 |
+
in_channels=block_in,
|
474 |
+
out_channels=block_out,
|
475 |
+
temb_channels=self.temb_ch,
|
476 |
+
dropout=dropout,
|
477 |
+
)
|
478 |
+
)
|
479 |
+
block_in = block_out
|
480 |
+
if curr_res in attn_resolutions:
|
481 |
+
attn.append(make_attn(block_in, attn_type=attn_type))
|
482 |
+
down = nn.Module()
|
483 |
+
down.block = block
|
484 |
+
down.attn = attn
|
485 |
+
if i_level != self.num_resolutions - 1:
|
486 |
+
if i_level in self.downsample_time_stride4_levels:
|
487 |
+
down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv)
|
488 |
+
else:
|
489 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
490 |
+
curr_res = curr_res // 2
|
491 |
+
self.down.append(down)
|
492 |
+
|
493 |
+
# middle
|
494 |
+
self.mid = nn.Module()
|
495 |
+
self.mid.block_1 = ResnetBlock(
|
496 |
+
in_channels=block_in,
|
497 |
+
out_channels=block_in,
|
498 |
+
temb_channels=self.temb_ch,
|
499 |
+
dropout=dropout,
|
500 |
+
)
|
501 |
+
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
502 |
+
self.mid.block_2 = ResnetBlock(
|
503 |
+
in_channels=block_in,
|
504 |
+
out_channels=block_in,
|
505 |
+
temb_channels=self.temb_ch,
|
506 |
+
dropout=dropout,
|
507 |
+
)
|
508 |
+
|
509 |
+
# end
|
510 |
+
self.norm_out = Normalize(block_in)
|
511 |
+
self.conv_out = torch.nn.Conv2d(
|
512 |
+
block_in,
|
513 |
+
2 * z_channels if double_z else z_channels,
|
514 |
+
kernel_size=3,
|
515 |
+
stride=1,
|
516 |
+
padding=1,
|
517 |
+
)
|
518 |
+
|
519 |
+
def forward(self, x):
|
520 |
+
# timestep embedding
|
521 |
+
temb = None
|
522 |
+
# downsampling
|
523 |
+
hs = [self.conv_in(x)]
|
524 |
+
for i_level in range(self.num_resolutions):
|
525 |
+
for i_block in range(self.num_res_blocks):
|
526 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
527 |
+
if len(self.down[i_level].attn) > 0:
|
528 |
+
h = self.down[i_level].attn[i_block](h)
|
529 |
+
hs.append(h)
|
530 |
+
if i_level != self.num_resolutions - 1:
|
531 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
532 |
+
|
533 |
+
# middle
|
534 |
+
h = hs[-1]
|
535 |
+
h = self.mid.block_1(h, temb)
|
536 |
+
h = self.mid.attn_1(h)
|
537 |
+
h = self.mid.block_2(h, temb)
|
538 |
+
|
539 |
+
# end
|
540 |
+
h = self.norm_out(h)
|
541 |
+
h = nonlinearity(h)
|
542 |
+
h = self.conv_out(h)
|
543 |
+
return h
|
544 |
+
|
545 |
+
|
546 |
+
class Decoder(nn.Module):
|
547 |
+
def __init__(
|
548 |
+
self,
|
549 |
+
*,
|
550 |
+
ch,
|
551 |
+
out_ch,
|
552 |
+
ch_mult=(1, 2, 4, 8),
|
553 |
+
num_res_blocks,
|
554 |
+
attn_resolutions,
|
555 |
+
dropout=0.0,
|
556 |
+
resamp_with_conv=True,
|
557 |
+
in_channels,
|
558 |
+
resolution,
|
559 |
+
z_channels,
|
560 |
+
give_pre_end=False,
|
561 |
+
tanh_out=False,
|
562 |
+
use_linear_attn=False,
|
563 |
+
downsample_time_stride4_levels=[],
|
564 |
+
attn_type="vanilla",
|
565 |
+
**ignorekwargs,
|
566 |
+
):
|
567 |
+
super().__init__()
|
568 |
+
if use_linear_attn:
|
569 |
+
attn_type = "linear"
|
570 |
+
self.ch = ch
|
571 |
+
self.temb_ch = 0
|
572 |
+
self.num_resolutions = len(ch_mult)
|
573 |
+
self.num_res_blocks = num_res_blocks
|
574 |
+
self.resolution = resolution
|
575 |
+
self.in_channels = in_channels
|
576 |
+
self.give_pre_end = give_pre_end
|
577 |
+
self.tanh_out = tanh_out
|
578 |
+
self.downsample_time_stride4_levels = downsample_time_stride4_levels
|
579 |
+
|
580 |
+
if len(self.downsample_time_stride4_levels) > 0:
|
581 |
+
assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
|
582 |
+
"The level to perform downsample 4 operation need to be smaller than "
|
583 |
+
"the total resolution number %s" % str(self.num_resolutions)
|
584 |
+
)
|
585 |
+
|
586 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
587 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
588 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
589 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
590 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
591 |
+
# print("Working with z of shape {} = {} dimensions.".format(
|
592 |
+
# self.z_shape, np.prod(self.z_shape)))
|
593 |
+
|
594 |
+
# z to block_in
|
595 |
+
self.conv_in = torch.nn.Conv2d(
|
596 |
+
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
597 |
+
)
|
598 |
+
|
599 |
+
# middle
|
600 |
+
self.mid = nn.Module()
|
601 |
+
self.mid.block_1 = ResnetBlock(
|
602 |
+
in_channels=block_in,
|
603 |
+
out_channels=block_in,
|
604 |
+
temb_channels=self.temb_ch,
|
605 |
+
dropout=dropout,
|
606 |
+
)
|
607 |
+
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
608 |
+
self.mid.block_2 = ResnetBlock(
|
609 |
+
in_channels=block_in,
|
610 |
+
out_channels=block_in,
|
611 |
+
temb_channels=self.temb_ch,
|
612 |
+
dropout=dropout,
|
613 |
+
)
|
614 |
+
|
615 |
+
# upsampling
|
616 |
+
self.up = nn.ModuleList()
|
617 |
+
for i_level in reversed(range(self.num_resolutions)):
|
618 |
+
block = nn.ModuleList()
|
619 |
+
attn = nn.ModuleList()
|
620 |
+
block_out = ch * ch_mult[i_level]
|
621 |
+
for i_block in range(self.num_res_blocks + 1):
|
622 |
+
block.append(
|
623 |
+
ResnetBlock(
|
624 |
+
in_channels=block_in,
|
625 |
+
out_channels=block_out,
|
626 |
+
temb_channels=self.temb_ch,
|
627 |
+
dropout=dropout,
|
628 |
+
)
|
629 |
+
)
|
630 |
+
block_in = block_out
|
631 |
+
if curr_res in attn_resolutions:
|
632 |
+
attn.append(make_attn(block_in, attn_type=attn_type))
|
633 |
+
up = nn.Module()
|
634 |
+
up.block = block
|
635 |
+
up.attn = attn
|
636 |
+
if i_level != 0:
|
637 |
+
if i_level - 1 in self.downsample_time_stride4_levels:
|
638 |
+
up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv)
|
639 |
+
else:
|
640 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
641 |
+
curr_res = curr_res * 2
|
642 |
+
self.up.insert(0, up) # prepend to get consistent order
|
643 |
+
|
644 |
+
# end
|
645 |
+
self.norm_out = Normalize(block_in)
|
646 |
+
self.conv_out = torch.nn.Conv2d(
|
647 |
+
block_in, out_ch, kernel_size=3, stride=1, padding=1
|
648 |
+
)
|
649 |
+
|
650 |
+
def forward(self, z):
|
651 |
+
# assert z.shape[1:] == self.z_shape[1:]
|
652 |
+
self.last_z_shape = z.shape
|
653 |
+
|
654 |
+
# timestep embedding
|
655 |
+
temb = None
|
656 |
+
|
657 |
+
# z to block_in
|
658 |
+
h = self.conv_in(z)
|
659 |
+
|
660 |
+
# middle
|
661 |
+
h = self.mid.block_1(h, temb)
|
662 |
+
h = self.mid.attn_1(h)
|
663 |
+
h = self.mid.block_2(h, temb)
|
664 |
+
|
665 |
+
# upsampling
|
666 |
+
for i_level in reversed(range(self.num_resolutions)):
|
667 |
+
for i_block in range(self.num_res_blocks + 1):
|
668 |
+
h = self.up[i_level].block[i_block](h.float(), temb)
|
669 |
+
if len(self.up[i_level].attn) > 0:
|
670 |
+
h = self.up[i_level].attn[i_block](h.float())
|
671 |
+
if i_level != 0:
|
672 |
+
h = self.up[i_level].upsample(h.float())
|
673 |
+
|
674 |
+
# end
|
675 |
+
if self.give_pre_end:
|
676 |
+
return h
|
677 |
+
|
678 |
+
h = self.norm_out(h)
|
679 |
+
h = nonlinearity(h)
|
680 |
+
h = self.conv_out(h)
|
681 |
+
if self.tanh_out:
|
682 |
+
h = torch.tanh(h)
|
683 |
+
return h
|
684 |
+
|
685 |
+
|
686 |
+
class SimpleDecoder(nn.Module):
|
687 |
+
def __init__(self, in_channels, out_channels, *args, **kwargs):
|
688 |
+
super().__init__()
|
689 |
+
self.model = nn.ModuleList(
|
690 |
+
[
|
691 |
+
nn.Conv2d(in_channels, in_channels, 1),
|
692 |
+
ResnetBlock(
|
693 |
+
in_channels=in_channels,
|
694 |
+
out_channels=2 * in_channels,
|
695 |
+
temb_channels=0,
|
696 |
+
dropout=0.0,
|
697 |
+
),
|
698 |
+
ResnetBlock(
|
699 |
+
in_channels=2 * in_channels,
|
700 |
+
out_channels=4 * in_channels,
|
701 |
+
temb_channels=0,
|
702 |
+
dropout=0.0,
|
703 |
+
),
|
704 |
+
ResnetBlock(
|
705 |
+
in_channels=4 * in_channels,
|
706 |
+
out_channels=2 * in_channels,
|
707 |
+
temb_channels=0,
|
708 |
+
dropout=0.0,
|
709 |
+
),
|
710 |
+
nn.Conv2d(2 * in_channels, in_channels, 1),
|
711 |
+
Upsample(in_channels, with_conv=True),
|
712 |
+
]
|
713 |
+
)
|
714 |
+
# end
|
715 |
+
self.norm_out = Normalize(in_channels)
|
716 |
+
self.conv_out = torch.nn.Conv2d(
|
717 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
718 |
+
)
|
719 |
+
|
720 |
+
def forward(self, x):
|
721 |
+
for i, layer in enumerate(self.model):
|
722 |
+
if i in [1, 2, 3]:
|
723 |
+
x = layer(x, None)
|
724 |
+
else:
|
725 |
+
x = layer(x)
|
726 |
+
|
727 |
+
h = self.norm_out(x)
|
728 |
+
h = nonlinearity(h)
|
729 |
+
x = self.conv_out(h)
|
730 |
+
return x
|
731 |
+
|
732 |
+
|
733 |
+
class UpsampleDecoder(nn.Module):
|
734 |
+
def __init__(
|
735 |
+
self,
|
736 |
+
in_channels,
|
737 |
+
out_channels,
|
738 |
+
ch,
|
739 |
+
num_res_blocks,
|
740 |
+
resolution,
|
741 |
+
ch_mult=(2, 2),
|
742 |
+
dropout=0.0,
|
743 |
+
):
|
744 |
+
super().__init__()
|
745 |
+
# upsampling
|
746 |
+
self.temb_ch = 0
|
747 |
+
self.num_resolutions = len(ch_mult)
|
748 |
+
self.num_res_blocks = num_res_blocks
|
749 |
+
block_in = in_channels
|
750 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
751 |
+
self.res_blocks = nn.ModuleList()
|
752 |
+
self.upsample_blocks = nn.ModuleList()
|
753 |
+
for i_level in range(self.num_resolutions):
|
754 |
+
res_block = []
|
755 |
+
block_out = ch * ch_mult[i_level]
|
756 |
+
for _ in range(self.num_res_blocks + 1):
|
757 |
+
res_block.append(
|
758 |
+
ResnetBlock(
|
759 |
+
in_channels=block_in,
|
760 |
+
out_channels=block_out,
|
761 |
+
temb_channels=self.temb_ch,
|
762 |
+
dropout=dropout,
|
763 |
+
)
|
764 |
+
)
|
765 |
+
block_in = block_out
|
766 |
+
self.res_blocks.append(nn.ModuleList(res_block))
|
767 |
+
if i_level != self.num_resolutions - 1:
|
768 |
+
self.upsample_blocks.append(Upsample(block_in, True))
|
769 |
+
curr_res = curr_res * 2
|
770 |
+
|
771 |
+
# end
|
772 |
+
self.norm_out = Normalize(block_in)
|
773 |
+
self.conv_out = torch.nn.Conv2d(
|
774 |
+
block_in, out_channels, kernel_size=3, stride=1, padding=1
|
775 |
+
)
|
776 |
+
|
777 |
+
def forward(self, x):
|
778 |
+
# upsampling
|
779 |
+
h = x
|
780 |
+
for k, i_level in enumerate(range(self.num_resolutions)):
|
781 |
+
for i_block in range(self.num_res_blocks + 1):
|
782 |
+
h = self.res_blocks[i_level][i_block](h, None)
|
783 |
+
if i_level != self.num_resolutions - 1:
|
784 |
+
h = self.upsample_blocks[k](h)
|
785 |
+
h = self.norm_out(h)
|
786 |
+
h = nonlinearity(h)
|
787 |
+
h = self.conv_out(h)
|
788 |
+
return h
|
789 |
+
|
790 |
+
|
791 |
+
class LatentRescaler(nn.Module):
|
792 |
+
def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
|
793 |
+
super().__init__()
|
794 |
+
# residual block, interpolate, residual block
|
795 |
+
self.factor = factor
|
796 |
+
self.conv_in = nn.Conv2d(
|
797 |
+
in_channels, mid_channels, kernel_size=3, stride=1, padding=1
|
798 |
+
)
|
799 |
+
self.res_block1 = nn.ModuleList(
|
800 |
+
[
|
801 |
+
ResnetBlock(
|
802 |
+
in_channels=mid_channels,
|
803 |
+
out_channels=mid_channels,
|
804 |
+
temb_channels=0,
|
805 |
+
dropout=0.0,
|
806 |
+
)
|
807 |
+
for _ in range(depth)
|
808 |
+
]
|
809 |
+
)
|
810 |
+
self.attn = AttnBlock(mid_channels)
|
811 |
+
self.res_block2 = nn.ModuleList(
|
812 |
+
[
|
813 |
+
ResnetBlock(
|
814 |
+
in_channels=mid_channels,
|
815 |
+
out_channels=mid_channels,
|
816 |
+
temb_channels=0,
|
817 |
+
dropout=0.0,
|
818 |
+
)
|
819 |
+
for _ in range(depth)
|
820 |
+
]
|
821 |
+
)
|
822 |
+
|
823 |
+
self.conv_out = nn.Conv2d(
|
824 |
+
mid_channels,
|
825 |
+
out_channels,
|
826 |
+
kernel_size=1,
|
827 |
+
)
|
828 |
+
|
829 |
+
def forward(self, x):
|
830 |
+
x = self.conv_in(x)
|
831 |
+
for block in self.res_block1:
|
832 |
+
x = block(x, None)
|
833 |
+
x = torch.nn.functional.interpolate(
|
834 |
+
x,
|
835 |
+
size=(
|
836 |
+
int(round(x.shape[2] * self.factor)),
|
837 |
+
int(round(x.shape[3] * self.factor)),
|
838 |
+
),
|
839 |
+
)
|
840 |
+
x = self.attn(x).contiguous()
|
841 |
+
for block in self.res_block2:
|
842 |
+
x = block(x, None)
|
843 |
+
x = self.conv_out(x)
|
844 |
+
return x
|
845 |
+
|
846 |
+
|
847 |
+
class MergedRescaleEncoder(nn.Module):
|
848 |
+
def __init__(
|
849 |
+
self,
|
850 |
+
in_channels,
|
851 |
+
ch,
|
852 |
+
resolution,
|
853 |
+
out_ch,
|
854 |
+
num_res_blocks,
|
855 |
+
attn_resolutions,
|
856 |
+
dropout=0.0,
|
857 |
+
resamp_with_conv=True,
|
858 |
+
ch_mult=(1, 2, 4, 8),
|
859 |
+
rescale_factor=1.0,
|
860 |
+
rescale_module_depth=1,
|
861 |
+
):
|
862 |
+
super().__init__()
|
863 |
+
intermediate_chn = ch * ch_mult[-1]
|
864 |
+
self.encoder = Encoder(
|
865 |
+
in_channels=in_channels,
|
866 |
+
num_res_blocks=num_res_blocks,
|
867 |
+
ch=ch,
|
868 |
+
ch_mult=ch_mult,
|
869 |
+
z_channels=intermediate_chn,
|
870 |
+
double_z=False,
|
871 |
+
resolution=resolution,
|
872 |
+
attn_resolutions=attn_resolutions,
|
873 |
+
dropout=dropout,
|
874 |
+
resamp_with_conv=resamp_with_conv,
|
875 |
+
out_ch=None,
|
876 |
+
)
|
877 |
+
self.rescaler = LatentRescaler(
|
878 |
+
factor=rescale_factor,
|
879 |
+
in_channels=intermediate_chn,
|
880 |
+
mid_channels=intermediate_chn,
|
881 |
+
out_channels=out_ch,
|
882 |
+
depth=rescale_module_depth,
|
883 |
+
)
|
884 |
+
|
885 |
+
def forward(self, x):
|
886 |
+
x = self.encoder(x)
|
887 |
+
x = self.rescaler(x)
|
888 |
+
return x
|
889 |
+
|
890 |
+
|
891 |
+
class MergedRescaleDecoder(nn.Module):
|
892 |
+
def __init__(
|
893 |
+
self,
|
894 |
+
z_channels,
|
895 |
+
out_ch,
|
896 |
+
resolution,
|
897 |
+
num_res_blocks,
|
898 |
+
attn_resolutions,
|
899 |
+
ch,
|
900 |
+
ch_mult=(1, 2, 4, 8),
|
901 |
+
dropout=0.0,
|
902 |
+
resamp_with_conv=True,
|
903 |
+
rescale_factor=1.0,
|
904 |
+
rescale_module_depth=1,
|
905 |
+
):
|
906 |
+
super().__init__()
|
907 |
+
tmp_chn = z_channels * ch_mult[-1]
|
908 |
+
self.decoder = Decoder(
|
909 |
+
out_ch=out_ch,
|
910 |
+
z_channels=tmp_chn,
|
911 |
+
attn_resolutions=attn_resolutions,
|
912 |
+
dropout=dropout,
|
913 |
+
resamp_with_conv=resamp_with_conv,
|
914 |
+
in_channels=None,
|
915 |
+
num_res_blocks=num_res_blocks,
|
916 |
+
ch_mult=ch_mult,
|
917 |
+
resolution=resolution,
|
918 |
+
ch=ch,
|
919 |
+
)
|
920 |
+
self.rescaler = LatentRescaler(
|
921 |
+
factor=rescale_factor,
|
922 |
+
in_channels=z_channels,
|
923 |
+
mid_channels=tmp_chn,
|
924 |
+
out_channels=tmp_chn,
|
925 |
+
depth=rescale_module_depth,
|
926 |
+
)
|
927 |
+
|
928 |
+
def forward(self, x):
|
929 |
+
x = self.rescaler(x)
|
930 |
+
x = self.decoder(x)
|
931 |
+
return x
|
932 |
+
|
933 |
+
|
934 |
+
class Upsampler(nn.Module):
|
935 |
+
def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
|
936 |
+
super().__init__()
|
937 |
+
assert out_size >= in_size
|
938 |
+
num_blocks = int(np.log2(out_size // in_size)) + 1
|
939 |
+
factor_up = 1.0 + (out_size % in_size)
|
940 |
+
print(
|
941 |
+
f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
|
942 |
+
)
|
943 |
+
self.rescaler = LatentRescaler(
|
944 |
+
factor=factor_up,
|
945 |
+
in_channels=in_channels,
|
946 |
+
mid_channels=2 * in_channels,
|
947 |
+
out_channels=in_channels,
|
948 |
+
)
|
949 |
+
self.decoder = Decoder(
|
950 |
+
out_ch=out_channels,
|
951 |
+
resolution=out_size,
|
952 |
+
z_channels=in_channels,
|
953 |
+
num_res_blocks=2,
|
954 |
+
attn_resolutions=[],
|
955 |
+
in_channels=None,
|
956 |
+
ch=in_channels,
|
957 |
+
ch_mult=[ch_mult for _ in range(num_blocks)],
|
958 |
+
)
|
959 |
+
|
960 |
+
def forward(self, x):
|
961 |
+
x = self.rescaler(x)
|
962 |
+
x = self.decoder(x)
|
963 |
+
return x
|
964 |
+
|
965 |
+
|
966 |
+
class Resize(nn.Module):
|
967 |
+
def __init__(self, in_channels=None, learned=False, mode="bilinear"):
|
968 |
+
super().__init__()
|
969 |
+
self.with_conv = learned
|
970 |
+
self.mode = mode
|
971 |
+
if self.with_conv:
|
972 |
+
print(
|
973 |
+
f"Note: {self.__class__.__name} uses learned downsampling "
|
974 |
+
f"and will ignore the fixed {mode} mode"
|
975 |
+
)
|
976 |
+
raise NotImplementedError()
|
977 |
+
assert in_channels is not None
|
978 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
979 |
+
self.conv = torch.nn.Conv2d(
|
980 |
+
in_channels, in_channels, kernel_size=4, stride=2, padding=1
|
981 |
+
)
|
982 |
+
|
983 |
+
def forward(self, x, scale_factor=1.0):
|
984 |
+
if scale_factor == 1.0:
|
985 |
+
return x
|
986 |
+
else:
|
987 |
+
x = torch.nn.functional.interpolate(
|
988 |
+
x, mode=self.mode, align_corners=False, scale_factor=scale_factor
|
989 |
+
)
|
990 |
+
return x
|
991 |
+
|
992 |
+
|
993 |
+
class FirstStagePostProcessor(nn.Module):
|
994 |
+
def __init__(
|
995 |
+
self,
|
996 |
+
ch_mult: list,
|
997 |
+
in_channels,
|
998 |
+
pretrained_model: nn.Module = None,
|
999 |
+
reshape=False,
|
1000 |
+
n_channels=None,
|
1001 |
+
dropout=0.0,
|
1002 |
+
pretrained_config=None,
|
1003 |
+
):
|
1004 |
+
super().__init__()
|
1005 |
+
if pretrained_config is None:
|
1006 |
+
assert (
|
1007 |
+
pretrained_model is not None
|
1008 |
+
), 'Either "pretrained_model" or "pretrained_config" must not be None'
|
1009 |
+
self.pretrained_model = pretrained_model
|
1010 |
+
else:
|
1011 |
+
assert (
|
1012 |
+
pretrained_config is not None
|
1013 |
+
), 'Either "pretrained_model" or "pretrained_config" must not be None'
|
1014 |
+
self.instantiate_pretrained(pretrained_config)
|
1015 |
+
|
1016 |
+
self.do_reshape = reshape
|
1017 |
+
|
1018 |
+
if n_channels is None:
|
1019 |
+
n_channels = self.pretrained_model.encoder.ch
|
1020 |
+
|
1021 |
+
self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)
|
1022 |
+
self.proj = nn.Conv2d(
|
1023 |
+
in_channels, n_channels, kernel_size=3, stride=1, padding=1
|
1024 |
+
)
|
1025 |
+
|
1026 |
+
blocks = []
|
1027 |
+
downs = []
|
1028 |
+
ch_in = n_channels
|
1029 |
+
for m in ch_mult:
|
1030 |
+
blocks.append(
|
1031 |
+
ResnetBlock(
|
1032 |
+
in_channels=ch_in, out_channels=m * n_channels, dropout=dropout
|
1033 |
+
)
|
1034 |
+
)
|
1035 |
+
ch_in = m * n_channels
|
1036 |
+
downs.append(Downsample(ch_in, with_conv=False))
|
1037 |
+
|
1038 |
+
self.model = nn.ModuleList(blocks)
|
1039 |
+
self.downsampler = nn.ModuleList(downs)
|
1040 |
+
|
1041 |
+
def instantiate_pretrained(self, config):
|
1042 |
+
model = instantiate_from_config(config)
|
1043 |
+
self.pretrained_model = model.eval()
|
1044 |
+
# self.pretrained_model.train = False
|
1045 |
+
for param in self.pretrained_model.parameters():
|
1046 |
+
param.requires_grad = False
|
1047 |
+
|
1048 |
+
@torch.no_grad()
|
1049 |
+
def encode_with_pretrained(self, x):
|
1050 |
+
c = self.pretrained_model.encode(x)
|
1051 |
+
if isinstance(c, DiagonalGaussianDistribution):
|
1052 |
+
c = c.mode()
|
1053 |
+
return c
|
1054 |
+
|
1055 |
+
def forward(self, x):
|
1056 |
+
z_fs = self.encode_with_pretrained(x)
|
1057 |
+
z = self.proj_norm(z_fs)
|
1058 |
+
z = self.proj(z)
|
1059 |
+
z = nonlinearity(z)
|
1060 |
+
|
1061 |
+
for submodel, downmodel in zip(self.model, self.downsampler):
|
1062 |
+
z = submodel(z, temb=None)
|
1063 |
+
z = downmodel(z)
|
1064 |
+
|
1065 |
+
if self.do_reshape:
|
1066 |
+
z = rearrange(z, "b c h w -> b (h w) c")
|
1067 |
+
return z
|
consistencytta.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn, Tensor
|
3 |
+
from transformers import AutoTokenizer, T5EncoderModel
|
4 |
+
|
5 |
+
from diffusers.utils.torch_utils import randn_tensor
|
6 |
+
from diffusers import UNet2DConditionGuidedModel, HeunDiscreteScheduler
|
7 |
+
from audioldm.stft import TacotronSTFT
|
8 |
+
from audioldm.variational_autoencoder import AutoencoderKL
|
9 |
+
from audioldm.utils import default_audioldm_config
|
10 |
+
|
11 |
+
|
12 |
+
class ConsistencyTTA(nn.Module):
|
13 |
+
|
14 |
+
def __init__(self):
|
15 |
+
super().__init__()
|
16 |
+
|
17 |
+
# Initialize the consistency U-Net
|
18 |
+
unet_model_config_path='tango_diffusion_light.json'
|
19 |
+
unet_config = UNet2DConditionGuidedModel.load_config(unet_model_config_path)
|
20 |
+
self.unet = UNet2DConditionGuidedModel.from_config(unet_config, subfolder="unet")
|
21 |
+
|
22 |
+
unet_weight_path = "consistencytta_clapft_ckpt/unet_state_dict.pt"
|
23 |
+
unet_weight_sd = torch.load(unet_weight_path, map_location='cpu')
|
24 |
+
self.unet.load_state_dict(unet_weight_sd)
|
25 |
+
|
26 |
+
# Initialize FLAN-T5 tokenizer and text encoder
|
27 |
+
text_encoder_name = 'google/flan-t5-large'
|
28 |
+
self.tokenizer = AutoTokenizer.from_pretrained(text_encoder_name)
|
29 |
+
self.text_encoder = T5EncoderModel.from_pretrained(text_encoder_name)
|
30 |
+
self.text_encoder.eval(); self.text_encoder.requires_grad_(False)
|
31 |
+
|
32 |
+
# Initialize the VAE
|
33 |
+
raw_vae_path = "consistencytta_clapft_ckpt/vae_state_dict.pt"
|
34 |
+
raw_vae_sd = torch.load(raw_vae_path, map_location="cpu")
|
35 |
+
vae_state_dict, scale_factor = raw_vae_sd["state_dict"], raw_vae_sd["scale_factor"]
|
36 |
+
|
37 |
+
config = default_audioldm_config('audioldm-s-full')
|
38 |
+
vae_config = config["model"]["params"]["first_stage_config"]["params"]
|
39 |
+
vae_config["scale_factor"] = scale_factor
|
40 |
+
|
41 |
+
self.vae = AutoencoderKL(**vae_config)
|
42 |
+
self.vae.load_state_dict(vae_state_dict)
|
43 |
+
self.vae.eval(); self.vae.requires_grad_(False)
|
44 |
+
|
45 |
+
# Initialize the STFT
|
46 |
+
self.fn_STFT = TacotronSTFT(
|
47 |
+
config["preprocessing"]["stft"]["filter_length"], # default 1024
|
48 |
+
config["preprocessing"]["stft"]["hop_length"], # default 160
|
49 |
+
config["preprocessing"]["stft"]["win_length"], # default 1024
|
50 |
+
config["preprocessing"]["mel"]["n_mel_channels"], # default 64
|
51 |
+
config["preprocessing"]["audio"]["sampling_rate"], # default 16000
|
52 |
+
config["preprocessing"]["mel"]["mel_fmin"], # default 0
|
53 |
+
config["preprocessing"]["mel"]["mel_fmax"], # default 8000
|
54 |
+
)
|
55 |
+
self.fn_STFT.eval(); self.fn_STFT.requires_grad_(False)
|
56 |
+
|
57 |
+
self.scheduler = HeunDiscreteScheduler.from_pretrained(
|
58 |
+
pretrained_model_name_or_path='stabilityai/stable-diffusion-2-1', subfolder="scheduler"
|
59 |
+
)
|
60 |
+
|
61 |
+
|
62 |
+
def train(self, mode: bool = True):
|
63 |
+
self.unet.train(mode)
|
64 |
+
for model in [self.text_encoder, self.vae, self.fn_STFT]:
|
65 |
+
model.eval()
|
66 |
+
return self
|
67 |
+
|
68 |
+
|
69 |
+
def eval(self):
|
70 |
+
return self.train(mode=False)
|
71 |
+
|
72 |
+
|
73 |
+
def check_eval_mode(self):
|
74 |
+
for model, name in zip(
|
75 |
+
[self.text_encoder, self.vae, self.fn_STFT, self.unet],
|
76 |
+
['text_encoder', 'vae', 'fn_STFT', 'unet']
|
77 |
+
):
|
78 |
+
assert model.training == False, f"The {name} is not in eval mode."
|
79 |
+
for param in model.parameters():
|
80 |
+
assert param.requires_grad == False, f"The {name} is not frozen."
|
81 |
+
|
82 |
+
|
83 |
+
@torch.no_grad()
|
84 |
+
def encode_text(self, prompt, max_length=None, padding=True):
|
85 |
+
device = self.text_encoder.device
|
86 |
+
if max_length is None:
|
87 |
+
max_length = self.tokenizer.model_max_length
|
88 |
+
|
89 |
+
batch = self.tokenizer(
|
90 |
+
prompt, max_length=max_length, padding=padding,
|
91 |
+
truncation=True, return_tensors="pt"
|
92 |
+
)
|
93 |
+
input_ids = batch.input_ids.to(device)
|
94 |
+
attention_mask = batch.attention_mask.to(device)
|
95 |
+
|
96 |
+
prompt_embeds = self.text_encoder(
|
97 |
+
input_ids=input_ids, attention_mask=attention_mask
|
98 |
+
)[0]
|
99 |
+
bool_prompt_mask = (attention_mask == 1).to(device) # Convert to boolean
|
100 |
+
return prompt_embeds, bool_prompt_mask
|
101 |
+
|
102 |
+
|
103 |
+
@torch.no_grad()
|
104 |
+
def encode_text_classifier_free(self, prompt: str, num_samples_per_prompt: int):
|
105 |
+
# get conditional embeddings
|
106 |
+
cond_prompt_embeds, cond_prompt_mask = self.encode_text(prompt)
|
107 |
+
cond_prompt_embeds = cond_prompt_embeds.repeat_interleave(
|
108 |
+
num_samples_per_prompt, 0
|
109 |
+
)
|
110 |
+
cond_prompt_mask = cond_prompt_mask.repeat_interleave(
|
111 |
+
num_samples_per_prompt, 0
|
112 |
+
)
|
113 |
+
|
114 |
+
# get unconditional embeddings for classifier free guidance
|
115 |
+
uncond_tokens = [""] * len(prompt)
|
116 |
+
negative_prompt_embeds, uncond_prompt_mask = self.encode_text(
|
117 |
+
uncond_tokens, max_length=cond_prompt_embeds.shape[1], padding="max_length"
|
118 |
+
)
|
119 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(
|
120 |
+
num_samples_per_prompt, 0
|
121 |
+
)
|
122 |
+
uncond_prompt_mask = uncond_prompt_mask.repeat_interleave(
|
123 |
+
num_samples_per_prompt, 0
|
124 |
+
)
|
125 |
+
|
126 |
+
""" For classifier-free guidance, we need to do two forward passes.
|
127 |
+
We concatenate the unconditional and text embeddings into a single batch
|
128 |
+
"""
|
129 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, cond_prompt_embeds])
|
130 |
+
prompt_mask = torch.cat([uncond_prompt_mask, cond_prompt_mask])
|
131 |
+
|
132 |
+
return prompt_embeds, prompt_mask, cond_prompt_embeds, cond_prompt_mask
|
133 |
+
|
134 |
+
|
135 |
+
def forward(
|
136 |
+
self, prompt: str, cfg_scale_input: float = 3., cfg_scale_post: float = 1.,
|
137 |
+
num_steps: int = 1, num_samples: int = 1, sr: int = 16000
|
138 |
+
):
|
139 |
+
self.check_eval_mode()
|
140 |
+
device = self.text_encoder.device
|
141 |
+
use_cf_guidance = cfg_scale_post > 1.
|
142 |
+
|
143 |
+
# Get prompt embeddings
|
144 |
+
prompt_embeds_cf, prompt_mask_cf, prompt_embeds, prompt_mask = \
|
145 |
+
self.encode_text_classifier_free(prompt, num_samples)
|
146 |
+
encoder_states, encoder_att_mask = \
|
147 |
+
(prompt_embeds_cf, prompt_mask_cf) if use_cf_guidance \
|
148 |
+
else (prompt_embeds, prompt_mask)
|
149 |
+
|
150 |
+
# Prepare noise
|
151 |
+
num_channels_latents = self.unet.config.in_channels
|
152 |
+
latent_shape = (len(prompt) * num_samples, num_channels_latents, 256, 16)
|
153 |
+
noise = randn_tensor(
|
154 |
+
latent_shape, generator=None, device=device, dtype=prompt_embeds.dtype
|
155 |
+
)
|
156 |
+
|
157 |
+
# Query the inference scheduler to obtain the time steps.
|
158 |
+
# The time steps spread between 0 and training time steps
|
159 |
+
self.scheduler.set_timesteps(18, device=device) # Set this to training steps first
|
160 |
+
z_N = noise * self.scheduler.init_noise_sigma
|
161 |
+
|
162 |
+
def calc_zhat_0(z_n: Tensor, t: int):
|
163 |
+
""" Query the consistency model to get zhat_0, which is the denoised embedding.
|
164 |
+
Args:
|
165 |
+
z_n (Tensor): The noisy embedding.
|
166 |
+
t (int): The time step.
|
167 |
+
Returns:
|
168 |
+
Tensor: The denoised embedding.
|
169 |
+
"""
|
170 |
+
# expand the latents if we are doing classifier free guidance
|
171 |
+
z_n_input = torch.cat([z_n] * 2) if use_cf_guidance else z_n
|
172 |
+
# Scale model input as required for some schedules.
|
173 |
+
z_n_input = self.scheduler.scale_model_input(z_n_input, t)
|
174 |
+
|
175 |
+
# Get zhat_0 from the model
|
176 |
+
zhat_0 = self.unet(
|
177 |
+
z_n_input, t, guidance=cfg_scale_input,
|
178 |
+
encoder_hidden_states=encoder_states, encoder_attention_mask=encoder_att_mask
|
179 |
+
).sample
|
180 |
+
|
181 |
+
# Perform external classifier-free guidance
|
182 |
+
if use_cf_guidance:
|
183 |
+
zhat_0_uncond, zhat_0_cond = zhat_0.chunk(2)
|
184 |
+
zhat_0 = (1 - cfg_scale_post) * zhat_0_uncond + cfg_scale_post * zhat_0_cond
|
185 |
+
|
186 |
+
return zhat_0
|
187 |
+
|
188 |
+
# Query the consistency model
|
189 |
+
zhat_0 = calc_zhat_0(z_N, self.scheduler.timesteps[0])
|
190 |
+
|
191 |
+
# Iteratively query the consistency model if requested
|
192 |
+
self.scheduler.set_timesteps(num_steps, device=device)
|
193 |
+
|
194 |
+
for t in self.scheduler.timesteps[1::2]: # 2 is the order of the scheduler
|
195 |
+
zhat_n = self.scheduler.add_noise(zhat_0, torch.randn_like(zhat_0), t)
|
196 |
+
# Calculate new zhat_0
|
197 |
+
zhat_0 = calc_zhat_0(zhat_n, t)
|
198 |
+
|
199 |
+
mel = self.vae.decode_first_stage(zhat_0.float())
|
200 |
+
return self.vae.decode_to_waveform(mel)[:, :int(sr * 9.5)] # Truncate to 9.6 seconds
|
consistencytta_clapft_ckpt/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
diffusers/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .scheduling_heun_discrete import HeunDiscreteScheduler
|
2 |
+
from .models.unet_2d_condition_guided import UNet2DConditionGuidedModel
|
diffusers/models/__init__.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from ..utils.import_utils import is_torch_available
|
16 |
+
|
17 |
+
|
18 |
+
if is_torch_available():
|
19 |
+
from .modeling_utils import ModelMixin
|
20 |
+
from .prior_transformer import PriorTransformer
|
21 |
+
from .unet_2d import UNet2DModel
|
22 |
+
from .unet_2d_condition import UNet2DConditionModel
|
23 |
+
from .unet_2d_condition_guided import UNet2DConditionGuidedModel
|
diffusers/models/activations.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
|
3 |
+
|
4 |
+
def get_activation(act_fn):
|
5 |
+
if act_fn in ["swish", "silu"]:
|
6 |
+
return nn.SiLU()
|
7 |
+
elif act_fn == "mish":
|
8 |
+
return nn.Mish()
|
9 |
+
elif act_fn == "gelu":
|
10 |
+
return nn.GELU()
|
11 |
+
else:
|
12 |
+
raise ValueError(f"Unsupported activation function: {act_fn}")
|
diffusers/models/attention.py
ADDED
@@ -0,0 +1,523 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import math
|
15 |
+
from typing import Any, Callable, Dict, Optional
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from torch import nn
|
20 |
+
|
21 |
+
from ..utils.import_utils import is_xformers_available
|
22 |
+
from .attention_processor import Attention
|
23 |
+
from .embeddings import CombinedTimestepLabelEmbeddings
|
24 |
+
|
25 |
+
|
26 |
+
if is_xformers_available():
|
27 |
+
import xformers
|
28 |
+
import xformers.ops
|
29 |
+
else:
|
30 |
+
xformers = None
|
31 |
+
|
32 |
+
|
33 |
+
class AttentionBlock(nn.Module):
|
34 |
+
"""
|
35 |
+
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
|
36 |
+
to the N-d case.
|
37 |
+
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
|
38 |
+
Uses three q, k, v linear layers to compute attention.
|
39 |
+
|
40 |
+
Parameters:
|
41 |
+
channels (`int`): The number of channels in the input and output.
|
42 |
+
num_head_channels (`int`, *optional*):
|
43 |
+
The number of channels in each head. If None, then `num_heads` = 1.
|
44 |
+
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
|
45 |
+
rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
|
46 |
+
eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
|
47 |
+
"""
|
48 |
+
|
49 |
+
# IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore
|
50 |
+
|
51 |
+
def __init__(
|
52 |
+
self,
|
53 |
+
channels: int,
|
54 |
+
num_head_channels: Optional[int] = None,
|
55 |
+
norm_num_groups: int = 32,
|
56 |
+
rescale_output_factor: float = 1.0,
|
57 |
+
eps: float = 1e-5,
|
58 |
+
):
|
59 |
+
super().__init__()
|
60 |
+
self.channels = channels
|
61 |
+
|
62 |
+
self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
|
63 |
+
self.num_head_size = num_head_channels
|
64 |
+
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
|
65 |
+
|
66 |
+
# define q,k,v as linear layers
|
67 |
+
self.query = nn.Linear(channels, channels)
|
68 |
+
self.key = nn.Linear(channels, channels)
|
69 |
+
self.value = nn.Linear(channels, channels)
|
70 |
+
|
71 |
+
self.rescale_output_factor = rescale_output_factor
|
72 |
+
self.proj_attn = nn.Linear(channels, channels, bias=True)
|
73 |
+
|
74 |
+
self._use_memory_efficient_attention_xformers = False
|
75 |
+
self._attention_op = None
|
76 |
+
|
77 |
+
def reshape_heads_to_batch_dim(self, tensor):
|
78 |
+
batch_size, seq_len, dim = tensor.shape
|
79 |
+
head_size = self.num_heads
|
80 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
81 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
|
82 |
+
return tensor
|
83 |
+
|
84 |
+
def reshape_batch_dim_to_heads(self, tensor):
|
85 |
+
batch_size, seq_len, dim = tensor.shape
|
86 |
+
head_size = self.num_heads
|
87 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
88 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
89 |
+
return tensor
|
90 |
+
|
91 |
+
def set_use_memory_efficient_attention_xformers(
|
92 |
+
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
|
93 |
+
):
|
94 |
+
if use_memory_efficient_attention_xformers:
|
95 |
+
if not is_xformers_available():
|
96 |
+
raise ModuleNotFoundError(
|
97 |
+
(
|
98 |
+
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
99 |
+
" xformers"
|
100 |
+
),
|
101 |
+
name="xformers",
|
102 |
+
)
|
103 |
+
elif not torch.cuda.is_available():
|
104 |
+
raise ValueError(
|
105 |
+
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
|
106 |
+
" only available for GPU "
|
107 |
+
)
|
108 |
+
else:
|
109 |
+
try:
|
110 |
+
# Make sure we can run the memory efficient attention
|
111 |
+
_ = xformers.ops.memory_efficient_attention(
|
112 |
+
torch.randn((1, 2, 40), device="cuda"),
|
113 |
+
torch.randn((1, 2, 40), device="cuda"),
|
114 |
+
torch.randn((1, 2, 40), device="cuda"),
|
115 |
+
)
|
116 |
+
except Exception as e:
|
117 |
+
raise e
|
118 |
+
self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
119 |
+
self._attention_op = attention_op
|
120 |
+
|
121 |
+
def forward(self, hidden_states):
|
122 |
+
residual = hidden_states
|
123 |
+
batch, channel, height, width = hidden_states.shape
|
124 |
+
|
125 |
+
# norm
|
126 |
+
hidden_states = self.group_norm(hidden_states)
|
127 |
+
|
128 |
+
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
|
129 |
+
|
130 |
+
# proj to q, k, v
|
131 |
+
query_proj = self.query(hidden_states)
|
132 |
+
key_proj = self.key(hidden_states)
|
133 |
+
value_proj = self.value(hidden_states)
|
134 |
+
|
135 |
+
scale = 1 / math.sqrt(self.channels / self.num_heads)
|
136 |
+
|
137 |
+
query_proj = self.reshape_heads_to_batch_dim(query_proj)
|
138 |
+
key_proj = self.reshape_heads_to_batch_dim(key_proj)
|
139 |
+
value_proj = self.reshape_heads_to_batch_dim(value_proj)
|
140 |
+
|
141 |
+
if self._use_memory_efficient_attention_xformers:
|
142 |
+
# Memory efficient attention
|
143 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
144 |
+
query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
|
145 |
+
)
|
146 |
+
hidden_states = hidden_states.to(query_proj.dtype)
|
147 |
+
else:
|
148 |
+
attention_scores = torch.baddbmm(
|
149 |
+
torch.empty(
|
150 |
+
query_proj.shape[0],
|
151 |
+
query_proj.shape[1],
|
152 |
+
key_proj.shape[1],
|
153 |
+
dtype=query_proj.dtype,
|
154 |
+
device=query_proj.device,
|
155 |
+
),
|
156 |
+
query_proj,
|
157 |
+
key_proj.transpose(-1, -2),
|
158 |
+
beta=0,
|
159 |
+
alpha=scale,
|
160 |
+
)
|
161 |
+
attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
|
162 |
+
hidden_states = torch.bmm(attention_probs, value_proj)
|
163 |
+
|
164 |
+
# reshape hidden_states
|
165 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
166 |
+
|
167 |
+
# compute next hidden_states
|
168 |
+
hidden_states = self.proj_attn(hidden_states)
|
169 |
+
|
170 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
|
171 |
+
|
172 |
+
# res connect and rescale
|
173 |
+
hidden_states = (hidden_states + residual) / self.rescale_output_factor
|
174 |
+
return hidden_states
|
175 |
+
|
176 |
+
|
177 |
+
class BasicTransformerBlock(nn.Module):
|
178 |
+
r"""
|
179 |
+
A basic Transformer block.
|
180 |
+
|
181 |
+
Parameters:
|
182 |
+
dim (`int`): The number of channels in the input and output.
|
183 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
184 |
+
attention_head_dim (`int`): The number of channels in each head.
|
185 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
186 |
+
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
187 |
+
only_cross_attention (`bool`, *optional*):
|
188 |
+
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
189 |
+
double_self_attention (`bool`, *optional*):
|
190 |
+
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
191 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
192 |
+
num_embeds_ada_norm (:
|
193 |
+
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
194 |
+
attention_bias (:
|
195 |
+
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
196 |
+
"""
|
197 |
+
|
198 |
+
def __init__(
|
199 |
+
self,
|
200 |
+
dim: int,
|
201 |
+
num_attention_heads: int,
|
202 |
+
attention_head_dim: int,
|
203 |
+
dropout=0.0,
|
204 |
+
cross_attention_dim: Optional[int] = None,
|
205 |
+
activation_fn: str = "geglu",
|
206 |
+
num_embeds_ada_norm: Optional[int] = None,
|
207 |
+
attention_bias: bool = False,
|
208 |
+
only_cross_attention: bool = False,
|
209 |
+
double_self_attention: bool = False,
|
210 |
+
upcast_attention: bool = False,
|
211 |
+
norm_elementwise_affine: bool = True,
|
212 |
+
norm_type: str = "layer_norm",
|
213 |
+
final_dropout: bool = False,
|
214 |
+
):
|
215 |
+
super().__init__()
|
216 |
+
self.only_cross_attention = only_cross_attention
|
217 |
+
|
218 |
+
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
219 |
+
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
220 |
+
|
221 |
+
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
222 |
+
raise ValueError(
|
223 |
+
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
224 |
+
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
225 |
+
)
|
226 |
+
|
227 |
+
# 1. Self-Attn
|
228 |
+
self.attn1 = Attention(
|
229 |
+
query_dim=dim,
|
230 |
+
heads=num_attention_heads,
|
231 |
+
dim_head=attention_head_dim,
|
232 |
+
dropout=dropout,
|
233 |
+
bias=attention_bias,
|
234 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
235 |
+
upcast_attention=upcast_attention,
|
236 |
+
)
|
237 |
+
|
238 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
|
239 |
+
|
240 |
+
# 2. Cross-Attn
|
241 |
+
if cross_attention_dim is not None or double_self_attention:
|
242 |
+
self.attn2 = Attention(
|
243 |
+
query_dim=dim,
|
244 |
+
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
245 |
+
heads=num_attention_heads,
|
246 |
+
dim_head=attention_head_dim,
|
247 |
+
dropout=dropout,
|
248 |
+
bias=attention_bias,
|
249 |
+
upcast_attention=upcast_attention,
|
250 |
+
) # is self-attn if encoder_hidden_states is none
|
251 |
+
else:
|
252 |
+
self.attn2 = None
|
253 |
+
|
254 |
+
if self.use_ada_layer_norm:
|
255 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
256 |
+
elif self.use_ada_layer_norm_zero:
|
257 |
+
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
258 |
+
else:
|
259 |
+
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
260 |
+
|
261 |
+
if cross_attention_dim is not None or double_self_attention:
|
262 |
+
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
263 |
+
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
264 |
+
# the second cross attention block.
|
265 |
+
self.norm2 = (
|
266 |
+
AdaLayerNorm(dim, num_embeds_ada_norm)
|
267 |
+
if self.use_ada_layer_norm
|
268 |
+
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
269 |
+
)
|
270 |
+
else:
|
271 |
+
self.norm2 = None
|
272 |
+
|
273 |
+
# 3. Feed-forward
|
274 |
+
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
275 |
+
|
276 |
+
def forward(
|
277 |
+
self,
|
278 |
+
hidden_states: torch.FloatTensor,
|
279 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
280 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
281 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
282 |
+
timestep: Optional[torch.LongTensor] = None,
|
283 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
284 |
+
class_labels: Optional[torch.LongTensor] = None,
|
285 |
+
):
|
286 |
+
if self.use_ada_layer_norm:
|
287 |
+
norm_hidden_states = self.norm1(hidden_states, timestep)
|
288 |
+
elif self.use_ada_layer_norm_zero:
|
289 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
290 |
+
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
291 |
+
)
|
292 |
+
else:
|
293 |
+
norm_hidden_states = self.norm1(hidden_states)
|
294 |
+
|
295 |
+
# 1. Self-Attention
|
296 |
+
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
297 |
+
attn_output = self.attn1(
|
298 |
+
norm_hidden_states,
|
299 |
+
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
300 |
+
attention_mask=attention_mask,
|
301 |
+
**cross_attention_kwargs,
|
302 |
+
)
|
303 |
+
if self.use_ada_layer_norm_zero:
|
304 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
305 |
+
hidden_states = attn_output + hidden_states
|
306 |
+
|
307 |
+
if self.attn2 is not None:
|
308 |
+
norm_hidden_states = (
|
309 |
+
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
310 |
+
)
|
311 |
+
|
312 |
+
# 2. Cross-Attention
|
313 |
+
attn_output = self.attn2(
|
314 |
+
norm_hidden_states,
|
315 |
+
encoder_hidden_states=encoder_hidden_states,
|
316 |
+
attention_mask=encoder_attention_mask,
|
317 |
+
**cross_attention_kwargs,
|
318 |
+
)
|
319 |
+
hidden_states = attn_output + hidden_states
|
320 |
+
|
321 |
+
# 3. Feed-forward
|
322 |
+
norm_hidden_states = self.norm3(hidden_states)
|
323 |
+
|
324 |
+
if self.use_ada_layer_norm_zero:
|
325 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
326 |
+
|
327 |
+
ff_output = self.ff(norm_hidden_states)
|
328 |
+
|
329 |
+
if self.use_ada_layer_norm_zero:
|
330 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
331 |
+
|
332 |
+
hidden_states = ff_output + hidden_states
|
333 |
+
|
334 |
+
return hidden_states
|
335 |
+
|
336 |
+
|
337 |
+
class FeedForward(nn.Module):
|
338 |
+
r"""
|
339 |
+
A feed-forward layer.
|
340 |
+
|
341 |
+
Parameters:
|
342 |
+
dim (`int`): The number of channels in the input.
|
343 |
+
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
344 |
+
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
345 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
346 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
347 |
+
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
|
348 |
+
"""
|
349 |
+
|
350 |
+
def __init__(
|
351 |
+
self,
|
352 |
+
dim: int,
|
353 |
+
dim_out: Optional[int] = None,
|
354 |
+
mult: int = 4,
|
355 |
+
dropout: float = 0.0,
|
356 |
+
activation_fn: str = "geglu",
|
357 |
+
final_dropout: bool = False,
|
358 |
+
):
|
359 |
+
super().__init__()
|
360 |
+
inner_dim = int(dim * mult)
|
361 |
+
dim_out = dim_out if dim_out is not None else dim
|
362 |
+
|
363 |
+
if activation_fn == "gelu":
|
364 |
+
act_fn = GELU(dim, inner_dim)
|
365 |
+
if activation_fn == "gelu-approximate":
|
366 |
+
act_fn = GELU(dim, inner_dim, approximate="tanh")
|
367 |
+
elif activation_fn == "geglu":
|
368 |
+
act_fn = GEGLU(dim, inner_dim)
|
369 |
+
elif activation_fn == "geglu-approximate":
|
370 |
+
act_fn = ApproximateGELU(dim, inner_dim)
|
371 |
+
|
372 |
+
self.net = nn.ModuleList([])
|
373 |
+
# project in
|
374 |
+
self.net.append(act_fn)
|
375 |
+
# project dropout
|
376 |
+
self.net.append(nn.Dropout(dropout))
|
377 |
+
# project out
|
378 |
+
self.net.append(nn.Linear(inner_dim, dim_out))
|
379 |
+
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
|
380 |
+
if final_dropout:
|
381 |
+
self.net.append(nn.Dropout(dropout))
|
382 |
+
|
383 |
+
def forward(self, hidden_states):
|
384 |
+
for module in self.net:
|
385 |
+
hidden_states = module(hidden_states)
|
386 |
+
return hidden_states
|
387 |
+
|
388 |
+
|
389 |
+
class GELU(nn.Module):
|
390 |
+
r"""
|
391 |
+
GELU activation function with tanh approximation support with `approximate="tanh"`.
|
392 |
+
"""
|
393 |
+
|
394 |
+
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
|
395 |
+
super().__init__()
|
396 |
+
self.proj = nn.Linear(dim_in, dim_out)
|
397 |
+
self.approximate = approximate
|
398 |
+
|
399 |
+
def gelu(self, gate):
|
400 |
+
if gate.device.type != "mps":
|
401 |
+
return F.gelu(gate, approximate=self.approximate)
|
402 |
+
# mps: gelu is not implemented for float16
|
403 |
+
return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
|
404 |
+
|
405 |
+
def forward(self, hidden_states):
|
406 |
+
hidden_states = self.proj(hidden_states)
|
407 |
+
hidden_states = self.gelu(hidden_states)
|
408 |
+
return hidden_states
|
409 |
+
|
410 |
+
|
411 |
+
class GEGLU(nn.Module):
|
412 |
+
r"""
|
413 |
+
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
|
414 |
+
|
415 |
+
Parameters:
|
416 |
+
dim_in (`int`): The number of channels in the input.
|
417 |
+
dim_out (`int`): The number of channels in the output.
|
418 |
+
"""
|
419 |
+
|
420 |
+
def __init__(self, dim_in: int, dim_out: int):
|
421 |
+
super().__init__()
|
422 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
423 |
+
|
424 |
+
def gelu(self, gate):
|
425 |
+
if gate.device.type != "mps":
|
426 |
+
return F.gelu(gate)
|
427 |
+
# mps: gelu is not implemented for float16
|
428 |
+
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
|
429 |
+
|
430 |
+
def forward(self, hidden_states):
|
431 |
+
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
|
432 |
+
return hidden_states * self.gelu(gate)
|
433 |
+
|
434 |
+
|
435 |
+
class ApproximateGELU(nn.Module):
|
436 |
+
"""
|
437 |
+
The approximate form of Gaussian Error Linear Unit (GELU)
|
438 |
+
|
439 |
+
For more details, see section 2: https://arxiv.org/abs/1606.08415
|
440 |
+
"""
|
441 |
+
|
442 |
+
def __init__(self, dim_in: int, dim_out: int):
|
443 |
+
super().__init__()
|
444 |
+
self.proj = nn.Linear(dim_in, dim_out)
|
445 |
+
|
446 |
+
def forward(self, x):
|
447 |
+
x = self.proj(x)
|
448 |
+
return x * torch.sigmoid(1.702 * x)
|
449 |
+
|
450 |
+
|
451 |
+
class AdaLayerNorm(nn.Module):
|
452 |
+
"""
|
453 |
+
Norm layer modified to incorporate timestep embeddings.
|
454 |
+
"""
|
455 |
+
|
456 |
+
def __init__(self, embedding_dim, num_embeddings):
|
457 |
+
super().__init__()
|
458 |
+
self.emb = nn.Embedding(num_embeddings, embedding_dim)
|
459 |
+
self.silu = nn.SiLU()
|
460 |
+
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
|
461 |
+
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
|
462 |
+
|
463 |
+
def forward(self, x, timestep):
|
464 |
+
emb = self.linear(self.silu(self.emb(timestep)))
|
465 |
+
scale, shift = torch.chunk(emb, 2)
|
466 |
+
x = self.norm(x) * (1 + scale) + shift
|
467 |
+
return x
|
468 |
+
|
469 |
+
|
470 |
+
class AdaLayerNormZero(nn.Module):
|
471 |
+
"""
|
472 |
+
Norm layer adaptive layer norm zero (adaLN-Zero).
|
473 |
+
"""
|
474 |
+
|
475 |
+
def __init__(self, embedding_dim, num_embeddings):
|
476 |
+
super().__init__()
|
477 |
+
|
478 |
+
self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
|
479 |
+
|
480 |
+
self.silu = nn.SiLU()
|
481 |
+
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
|
482 |
+
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
483 |
+
|
484 |
+
def forward(self, x, timestep, class_labels, hidden_dtype=None):
|
485 |
+
emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
|
486 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
|
487 |
+
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
488 |
+
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
489 |
+
|
490 |
+
|
491 |
+
class AdaGroupNorm(nn.Module):
|
492 |
+
"""
|
493 |
+
GroupNorm layer modified to incorporate timestep embeddings.
|
494 |
+
"""
|
495 |
+
|
496 |
+
def __init__(
|
497 |
+
self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
|
498 |
+
):
|
499 |
+
super().__init__()
|
500 |
+
self.num_groups = num_groups
|
501 |
+
self.eps = eps
|
502 |
+
self.act = None
|
503 |
+
if act_fn == "swish":
|
504 |
+
self.act = lambda x: F.silu(x)
|
505 |
+
elif act_fn == "mish":
|
506 |
+
self.act = nn.Mish()
|
507 |
+
elif act_fn == "silu":
|
508 |
+
self.act = nn.SiLU()
|
509 |
+
elif act_fn == "gelu":
|
510 |
+
self.act = nn.GELU()
|
511 |
+
|
512 |
+
self.linear = nn.Linear(embedding_dim, out_dim * 2)
|
513 |
+
|
514 |
+
def forward(self, x, emb):
|
515 |
+
if self.act:
|
516 |
+
emb = self.act(emb)
|
517 |
+
emb = self.linear(emb)
|
518 |
+
emb = emb[:, :, None, None]
|
519 |
+
scale, shift = emb.chunk(2, dim=1)
|
520 |
+
|
521 |
+
x = F.group_norm(x, self.num_groups, eps=self.eps)
|
522 |
+
x = x * (1 + scale) + shift
|
523 |
+
return x
|
diffusers/models/attention_processor.py
ADDED
@@ -0,0 +1,1646 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Callable, Optional, Union
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from torch import nn
|
19 |
+
|
20 |
+
from ..utils.deprecation_utils import deprecate
|
21 |
+
from ..utils.torch_utils import maybe_allow_in_graph
|
22 |
+
from ..utils.import_utils import is_xformers_available
|
23 |
+
from ..utils.logging import get_logger
|
24 |
+
|
25 |
+
logger = get_logger(__name__) # pylint: disable=invalid-name
|
26 |
+
|
27 |
+
|
28 |
+
if is_xformers_available():
|
29 |
+
import xformers
|
30 |
+
import xformers.ops
|
31 |
+
else:
|
32 |
+
xformers = None
|
33 |
+
|
34 |
+
|
35 |
+
@maybe_allow_in_graph
|
36 |
+
class Attention(nn.Module):
|
37 |
+
r"""
|
38 |
+
A cross attention layer.
|
39 |
+
|
40 |
+
Parameters:
|
41 |
+
query_dim (`int`): The number of channels in the query.
|
42 |
+
cross_attention_dim (`int`, *optional*):
|
43 |
+
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
44 |
+
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
45 |
+
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
|
46 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
47 |
+
bias (`bool`, *optional*, defaults to False):
|
48 |
+
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
49 |
+
"""
|
50 |
+
|
51 |
+
def __init__(
|
52 |
+
self,
|
53 |
+
query_dim: int,
|
54 |
+
cross_attention_dim: Optional[int] = None,
|
55 |
+
heads: int = 8,
|
56 |
+
dim_head: int = 64,
|
57 |
+
dropout: float = 0.0,
|
58 |
+
bias=False,
|
59 |
+
upcast_attention: bool = False,
|
60 |
+
upcast_softmax: bool = False,
|
61 |
+
cross_attention_norm: Optional[str] = None,
|
62 |
+
cross_attention_norm_num_groups: int = 32,
|
63 |
+
added_kv_proj_dim: Optional[int] = None,
|
64 |
+
norm_num_groups: Optional[int] = None,
|
65 |
+
spatial_norm_dim: Optional[int] = None,
|
66 |
+
out_bias: bool = True,
|
67 |
+
scale_qk: bool = True,
|
68 |
+
only_cross_attention: bool = False,
|
69 |
+
eps: float = 1e-5,
|
70 |
+
rescale_output_factor: float = 1.0,
|
71 |
+
residual_connection: bool = False,
|
72 |
+
_from_deprecated_attn_block=False,
|
73 |
+
processor: Optional["AttnProcessor"] = None,
|
74 |
+
):
|
75 |
+
super().__init__()
|
76 |
+
inner_dim = dim_head * heads
|
77 |
+
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
78 |
+
self.upcast_attention = upcast_attention
|
79 |
+
self.upcast_softmax = upcast_softmax
|
80 |
+
self.rescale_output_factor = rescale_output_factor
|
81 |
+
self.residual_connection = residual_connection
|
82 |
+
self.dropout = dropout
|
83 |
+
|
84 |
+
# we make use of this private variable to know whether this class is loaded
|
85 |
+
# with an deprecated state dict so that we can convert it on the fly
|
86 |
+
self._from_deprecated_attn_block = _from_deprecated_attn_block
|
87 |
+
|
88 |
+
self.scale_qk = scale_qk
|
89 |
+
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
90 |
+
|
91 |
+
self.heads = heads
|
92 |
+
# for slice_size > 0 the attention score computation
|
93 |
+
# is split across the batch axis to save memory
|
94 |
+
# You can set slice_size with `set_attention_slice`
|
95 |
+
self.sliceable_head_dim = heads
|
96 |
+
|
97 |
+
self.added_kv_proj_dim = added_kv_proj_dim
|
98 |
+
self.only_cross_attention = only_cross_attention
|
99 |
+
|
100 |
+
if self.added_kv_proj_dim is None and self.only_cross_attention:
|
101 |
+
raise ValueError(
|
102 |
+
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
|
103 |
+
)
|
104 |
+
|
105 |
+
if norm_num_groups is not None:
|
106 |
+
self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
|
107 |
+
else:
|
108 |
+
self.group_norm = None
|
109 |
+
|
110 |
+
if spatial_norm_dim is not None:
|
111 |
+
self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
|
112 |
+
else:
|
113 |
+
self.spatial_norm = None
|
114 |
+
|
115 |
+
if cross_attention_norm is None:
|
116 |
+
self.norm_cross = None
|
117 |
+
elif cross_attention_norm == "layer_norm":
|
118 |
+
self.norm_cross = nn.LayerNorm(cross_attention_dim)
|
119 |
+
elif cross_attention_norm == "group_norm":
|
120 |
+
if self.added_kv_proj_dim is not None:
|
121 |
+
# The given `encoder_hidden_states` are initially of shape
|
122 |
+
# (batch_size, seq_len, added_kv_proj_dim) before being projected
|
123 |
+
# to (batch_size, seq_len, cross_attention_dim). The norm is applied
|
124 |
+
# before the projection, so we need to use `added_kv_proj_dim` as
|
125 |
+
# the number of channels for the group norm.
|
126 |
+
norm_cross_num_channels = added_kv_proj_dim
|
127 |
+
else:
|
128 |
+
norm_cross_num_channels = cross_attention_dim
|
129 |
+
|
130 |
+
self.norm_cross = nn.GroupNorm(
|
131 |
+
num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
|
132 |
+
)
|
133 |
+
else:
|
134 |
+
raise ValueError(
|
135 |
+
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
|
136 |
+
)
|
137 |
+
|
138 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
139 |
+
|
140 |
+
if not self.only_cross_attention:
|
141 |
+
# only relevant for the `AddedKVProcessor` classes
|
142 |
+
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
143 |
+
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
144 |
+
else:
|
145 |
+
self.to_k = None
|
146 |
+
self.to_v = None
|
147 |
+
|
148 |
+
if self.added_kv_proj_dim is not None:
|
149 |
+
self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
|
150 |
+
self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)
|
151 |
+
|
152 |
+
self.to_out = nn.ModuleList([])
|
153 |
+
self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
|
154 |
+
self.to_out.append(nn.Dropout(dropout))
|
155 |
+
|
156 |
+
# set attention processor
|
157 |
+
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
158 |
+
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
159 |
+
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
160 |
+
if processor is None:
|
161 |
+
processor = (
|
162 |
+
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
|
163 |
+
)
|
164 |
+
self.set_processor(processor)
|
165 |
+
|
166 |
+
def set_use_memory_efficient_attention_xformers(
|
167 |
+
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
|
168 |
+
):
|
169 |
+
is_lora = hasattr(self, "processor") and isinstance(
|
170 |
+
self.processor,
|
171 |
+
(LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor),
|
172 |
+
)
|
173 |
+
is_custom_diffusion = hasattr(self, "processor") and isinstance(
|
174 |
+
self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
|
175 |
+
)
|
176 |
+
is_added_kv_processor = hasattr(self, "processor") and isinstance(
|
177 |
+
self.processor,
|
178 |
+
(
|
179 |
+
AttnAddedKVProcessor,
|
180 |
+
AttnAddedKVProcessor2_0,
|
181 |
+
SlicedAttnAddedKVProcessor,
|
182 |
+
XFormersAttnAddedKVProcessor,
|
183 |
+
LoRAAttnAddedKVProcessor,
|
184 |
+
),
|
185 |
+
)
|
186 |
+
|
187 |
+
if use_memory_efficient_attention_xformers:
|
188 |
+
if is_added_kv_processor and (is_lora or is_custom_diffusion):
|
189 |
+
raise NotImplementedError(
|
190 |
+
f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}"
|
191 |
+
)
|
192 |
+
if not is_xformers_available():
|
193 |
+
raise ModuleNotFoundError(
|
194 |
+
(
|
195 |
+
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
196 |
+
" xformers"
|
197 |
+
),
|
198 |
+
name="xformers",
|
199 |
+
)
|
200 |
+
elif not torch.cuda.is_available():
|
201 |
+
raise ValueError(
|
202 |
+
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
|
203 |
+
" only available for GPU "
|
204 |
+
)
|
205 |
+
else:
|
206 |
+
try:
|
207 |
+
# Make sure we can run the memory efficient attention
|
208 |
+
_ = xformers.ops.memory_efficient_attention(
|
209 |
+
torch.randn((1, 2, 40), device="cuda"),
|
210 |
+
torch.randn((1, 2, 40), device="cuda"),
|
211 |
+
torch.randn((1, 2, 40), device="cuda"),
|
212 |
+
)
|
213 |
+
except Exception as e:
|
214 |
+
raise e
|
215 |
+
|
216 |
+
if is_lora:
|
217 |
+
# TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
|
218 |
+
# variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
|
219 |
+
processor = LoRAXFormersAttnProcessor(
|
220 |
+
hidden_size=self.processor.hidden_size,
|
221 |
+
cross_attention_dim=self.processor.cross_attention_dim,
|
222 |
+
rank=self.processor.rank,
|
223 |
+
attention_op=attention_op,
|
224 |
+
)
|
225 |
+
processor.load_state_dict(self.processor.state_dict())
|
226 |
+
processor.to(self.processor.to_q_lora.up.weight.device)
|
227 |
+
elif is_custom_diffusion:
|
228 |
+
processor = CustomDiffusionXFormersAttnProcessor(
|
229 |
+
train_kv=self.processor.train_kv,
|
230 |
+
train_q_out=self.processor.train_q_out,
|
231 |
+
hidden_size=self.processor.hidden_size,
|
232 |
+
cross_attention_dim=self.processor.cross_attention_dim,
|
233 |
+
attention_op=attention_op,
|
234 |
+
)
|
235 |
+
processor.load_state_dict(self.processor.state_dict())
|
236 |
+
if hasattr(self.processor, "to_k_custom_diffusion"):
|
237 |
+
processor.to(self.processor.to_k_custom_diffusion.weight.device)
|
238 |
+
elif is_added_kv_processor:
|
239 |
+
# TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
|
240 |
+
# which uses this type of cross attention ONLY because the attention mask of format
|
241 |
+
# [0, ..., -10.000, ..., 0, ...,] is not supported
|
242 |
+
# throw warning
|
243 |
+
logger.info(
|
244 |
+
"Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
|
245 |
+
)
|
246 |
+
processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
|
247 |
+
else:
|
248 |
+
processor = XFormersAttnProcessor(attention_op=attention_op)
|
249 |
+
else:
|
250 |
+
if is_lora:
|
251 |
+
attn_processor_class = (
|
252 |
+
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
253 |
+
)
|
254 |
+
processor = attn_processor_class(
|
255 |
+
hidden_size=self.processor.hidden_size,
|
256 |
+
cross_attention_dim=self.processor.cross_attention_dim,
|
257 |
+
rank=self.processor.rank,
|
258 |
+
)
|
259 |
+
processor.load_state_dict(self.processor.state_dict())
|
260 |
+
processor.to(self.processor.to_q_lora.up.weight.device)
|
261 |
+
elif is_custom_diffusion:
|
262 |
+
processor = CustomDiffusionAttnProcessor(
|
263 |
+
train_kv=self.processor.train_kv,
|
264 |
+
train_q_out=self.processor.train_q_out,
|
265 |
+
hidden_size=self.processor.hidden_size,
|
266 |
+
cross_attention_dim=self.processor.cross_attention_dim,
|
267 |
+
)
|
268 |
+
processor.load_state_dict(self.processor.state_dict())
|
269 |
+
if hasattr(self.processor, "to_k_custom_diffusion"):
|
270 |
+
processor.to(self.processor.to_k_custom_diffusion.weight.device)
|
271 |
+
else:
|
272 |
+
# set attention processor
|
273 |
+
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
274 |
+
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
275 |
+
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
276 |
+
processor = (
|
277 |
+
AttnProcessor2_0()
|
278 |
+
if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
|
279 |
+
else AttnProcessor()
|
280 |
+
)
|
281 |
+
|
282 |
+
self.set_processor(processor)
|
283 |
+
|
284 |
+
def set_attention_slice(self, slice_size):
|
285 |
+
if slice_size is not None and slice_size > self.sliceable_head_dim:
|
286 |
+
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
|
287 |
+
|
288 |
+
if slice_size is not None and self.added_kv_proj_dim is not None:
|
289 |
+
processor = SlicedAttnAddedKVProcessor(slice_size)
|
290 |
+
elif slice_size is not None:
|
291 |
+
processor = SlicedAttnProcessor(slice_size)
|
292 |
+
elif self.added_kv_proj_dim is not None:
|
293 |
+
processor = AttnAddedKVProcessor()
|
294 |
+
else:
|
295 |
+
# set attention processor
|
296 |
+
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
297 |
+
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
298 |
+
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
299 |
+
processor = (
|
300 |
+
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
|
301 |
+
)
|
302 |
+
|
303 |
+
self.set_processor(processor)
|
304 |
+
|
305 |
+
def set_processor(self, processor: "AttnProcessor"):
|
306 |
+
# if current processor is in `self._modules` and if passed `processor` is not, we need to
|
307 |
+
# pop `processor` from `self._modules`
|
308 |
+
if (
|
309 |
+
hasattr(self, "processor")
|
310 |
+
and isinstance(self.processor, torch.nn.Module)
|
311 |
+
and not isinstance(processor, torch.nn.Module)
|
312 |
+
):
|
313 |
+
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
|
314 |
+
self._modules.pop("processor")
|
315 |
+
|
316 |
+
self.processor = processor
|
317 |
+
|
318 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
|
319 |
+
# The `Attention` class can call different attention processors / attention functions
|
320 |
+
# here we simply pass along all tensors to the selected processor class
|
321 |
+
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
322 |
+
return self.processor(
|
323 |
+
self,
|
324 |
+
hidden_states,
|
325 |
+
encoder_hidden_states=encoder_hidden_states,
|
326 |
+
attention_mask=attention_mask,
|
327 |
+
**cross_attention_kwargs,
|
328 |
+
)
|
329 |
+
|
330 |
+
def batch_to_head_dim(self, tensor):
|
331 |
+
head_size = self.heads
|
332 |
+
batch_size, seq_len, dim = tensor.shape
|
333 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
334 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
335 |
+
return tensor
|
336 |
+
|
337 |
+
def head_to_batch_dim(self, tensor, out_dim=3):
|
338 |
+
head_size = self.heads
|
339 |
+
batch_size, seq_len, dim = tensor.shape
|
340 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
341 |
+
tensor = tensor.permute(0, 2, 1, 3)
|
342 |
+
|
343 |
+
if out_dim == 3:
|
344 |
+
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
|
345 |
+
|
346 |
+
return tensor
|
347 |
+
|
348 |
+
def get_attention_scores(self, query, key, attention_mask=None):
|
349 |
+
dtype = query.dtype
|
350 |
+
if self.upcast_attention:
|
351 |
+
query = query.float()
|
352 |
+
key = key.float()
|
353 |
+
|
354 |
+
if attention_mask is None:
|
355 |
+
baddbmm_input = torch.empty(
|
356 |
+
query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
|
357 |
+
)
|
358 |
+
beta = 0
|
359 |
+
else:
|
360 |
+
baddbmm_input = attention_mask
|
361 |
+
beta = 1
|
362 |
+
|
363 |
+
attention_scores = torch.baddbmm(
|
364 |
+
baddbmm_input,
|
365 |
+
query,
|
366 |
+
key.transpose(-1, -2),
|
367 |
+
beta=beta,
|
368 |
+
alpha=self.scale,
|
369 |
+
)
|
370 |
+
del baddbmm_input
|
371 |
+
|
372 |
+
if self.upcast_softmax:
|
373 |
+
attention_scores = attention_scores.float()
|
374 |
+
|
375 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
376 |
+
del attention_scores
|
377 |
+
|
378 |
+
attention_probs = attention_probs.to(dtype)
|
379 |
+
|
380 |
+
return attention_probs
|
381 |
+
|
382 |
+
def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
|
383 |
+
if batch_size is None:
|
384 |
+
deprecate(
|
385 |
+
"batch_size=None",
|
386 |
+
"0.0.15",
|
387 |
+
(
|
388 |
+
"Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
|
389 |
+
" attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
|
390 |
+
" `prepare_attention_mask` when preparing the attention_mask."
|
391 |
+
),
|
392 |
+
)
|
393 |
+
batch_size = 1
|
394 |
+
|
395 |
+
head_size = self.heads
|
396 |
+
if attention_mask is None:
|
397 |
+
return attention_mask
|
398 |
+
|
399 |
+
current_length: int = attention_mask.shape[-1]
|
400 |
+
if current_length != target_length:
|
401 |
+
if attention_mask.device.type == "mps":
|
402 |
+
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
|
403 |
+
# Instead, we can manually construct the padding tensor.
|
404 |
+
padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
|
405 |
+
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
|
406 |
+
attention_mask = torch.cat([attention_mask, padding], dim=2)
|
407 |
+
else:
|
408 |
+
# TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
|
409 |
+
# we want to instead pad by (0, remaining_length), where remaining_length is:
|
410 |
+
# remaining_length: int = target_length - current_length
|
411 |
+
# TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
|
412 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
413 |
+
|
414 |
+
if out_dim == 3:
|
415 |
+
if attention_mask.shape[0] < batch_size * head_size:
|
416 |
+
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
|
417 |
+
elif out_dim == 4:
|
418 |
+
attention_mask = attention_mask.unsqueeze(1)
|
419 |
+
attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
|
420 |
+
|
421 |
+
return attention_mask
|
422 |
+
|
423 |
+
def norm_encoder_hidden_states(self, encoder_hidden_states):
|
424 |
+
assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
|
425 |
+
|
426 |
+
if isinstance(self.norm_cross, nn.LayerNorm):
|
427 |
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
428 |
+
elif isinstance(self.norm_cross, nn.GroupNorm):
|
429 |
+
# Group norm norms along the channels dimension and expects
|
430 |
+
# input to be in the shape of (N, C, *). In this case, we want
|
431 |
+
# to norm along the hidden dimension, so we need to move
|
432 |
+
# (batch_size, sequence_length, hidden_size) ->
|
433 |
+
# (batch_size, hidden_size, sequence_length)
|
434 |
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
435 |
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
436 |
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
437 |
+
else:
|
438 |
+
assert False
|
439 |
+
|
440 |
+
return encoder_hidden_states
|
441 |
+
|
442 |
+
|
443 |
+
class AttnProcessor:
|
444 |
+
r"""
|
445 |
+
Default processor for performing attention-related computations.
|
446 |
+
"""
|
447 |
+
|
448 |
+
def __call__(
|
449 |
+
self,
|
450 |
+
attn: Attention,
|
451 |
+
hidden_states,
|
452 |
+
encoder_hidden_states=None,
|
453 |
+
attention_mask=None,
|
454 |
+
temb=None,
|
455 |
+
):
|
456 |
+
residual = hidden_states
|
457 |
+
|
458 |
+
if attn.spatial_norm is not None:
|
459 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
460 |
+
|
461 |
+
input_ndim = hidden_states.ndim
|
462 |
+
|
463 |
+
if input_ndim == 4:
|
464 |
+
batch_size, channel, height, width = hidden_states.shape
|
465 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
466 |
+
|
467 |
+
batch_size, sequence_length, _ = (
|
468 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
469 |
+
)
|
470 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
471 |
+
|
472 |
+
if attn.group_norm is not None:
|
473 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
474 |
+
|
475 |
+
query = attn.to_q(hidden_states)
|
476 |
+
|
477 |
+
if encoder_hidden_states is None:
|
478 |
+
encoder_hidden_states = hidden_states
|
479 |
+
elif attn.norm_cross:
|
480 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
481 |
+
|
482 |
+
key = attn.to_k(encoder_hidden_states)
|
483 |
+
value = attn.to_v(encoder_hidden_states)
|
484 |
+
|
485 |
+
query = attn.head_to_batch_dim(query)
|
486 |
+
key = attn.head_to_batch_dim(key)
|
487 |
+
value = attn.head_to_batch_dim(value)
|
488 |
+
|
489 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
490 |
+
hidden_states = torch.bmm(attention_probs, value)
|
491 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
492 |
+
|
493 |
+
# linear proj
|
494 |
+
hidden_states = attn.to_out[0](hidden_states)
|
495 |
+
# dropout
|
496 |
+
hidden_states = attn.to_out[1](hidden_states)
|
497 |
+
|
498 |
+
if input_ndim == 4:
|
499 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
500 |
+
|
501 |
+
if attn.residual_connection:
|
502 |
+
hidden_states = hidden_states + residual
|
503 |
+
|
504 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
505 |
+
|
506 |
+
return hidden_states
|
507 |
+
|
508 |
+
|
509 |
+
class LoRALinearLayer(nn.Module):
|
510 |
+
def __init__(self, in_features, out_features, rank=4, network_alpha=None):
|
511 |
+
super().__init__()
|
512 |
+
|
513 |
+
if rank > min(in_features, out_features):
|
514 |
+
raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
|
515 |
+
|
516 |
+
self.down = nn.Linear(in_features, rank, bias=False)
|
517 |
+
self.up = nn.Linear(rank, out_features, bias=False)
|
518 |
+
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
|
519 |
+
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
|
520 |
+
self.network_alpha = network_alpha
|
521 |
+
self.rank = rank
|
522 |
+
|
523 |
+
nn.init.normal_(self.down.weight, std=1 / rank)
|
524 |
+
nn.init.zeros_(self.up.weight)
|
525 |
+
|
526 |
+
def forward(self, hidden_states):
|
527 |
+
orig_dtype = hidden_states.dtype
|
528 |
+
dtype = self.down.weight.dtype
|
529 |
+
|
530 |
+
down_hidden_states = self.down(hidden_states.to(dtype))
|
531 |
+
up_hidden_states = self.up(down_hidden_states)
|
532 |
+
|
533 |
+
if self.network_alpha is not None:
|
534 |
+
up_hidden_states *= self.network_alpha / self.rank
|
535 |
+
|
536 |
+
return up_hidden_states.to(orig_dtype)
|
537 |
+
|
538 |
+
|
539 |
+
class LoRAAttnProcessor(nn.Module):
|
540 |
+
r"""
|
541 |
+
Processor for implementing the LoRA attention mechanism.
|
542 |
+
|
543 |
+
Args:
|
544 |
+
hidden_size (`int`, *optional*):
|
545 |
+
The hidden size of the attention layer.
|
546 |
+
cross_attention_dim (`int`, *optional*):
|
547 |
+
The number of channels in the `encoder_hidden_states`.
|
548 |
+
rank (`int`, defaults to 4):
|
549 |
+
The dimension of the LoRA update matrices.
|
550 |
+
network_alpha (`int`, *optional*):
|
551 |
+
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
|
552 |
+
"""
|
553 |
+
|
554 |
+
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
|
555 |
+
super().__init__()
|
556 |
+
|
557 |
+
self.hidden_size = hidden_size
|
558 |
+
self.cross_attention_dim = cross_attention_dim
|
559 |
+
self.rank = rank
|
560 |
+
|
561 |
+
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
562 |
+
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
563 |
+
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
564 |
+
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
565 |
+
|
566 |
+
def __call__(
|
567 |
+
self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
|
568 |
+
):
|
569 |
+
residual = hidden_states
|
570 |
+
|
571 |
+
if attn.spatial_norm is not None:
|
572 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
573 |
+
|
574 |
+
input_ndim = hidden_states.ndim
|
575 |
+
|
576 |
+
if input_ndim == 4:
|
577 |
+
batch_size, channel, height, width = hidden_states.shape
|
578 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
579 |
+
|
580 |
+
batch_size, sequence_length, _ = (
|
581 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
582 |
+
)
|
583 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
584 |
+
|
585 |
+
if attn.group_norm is not None:
|
586 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
587 |
+
|
588 |
+
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
|
589 |
+
query = attn.head_to_batch_dim(query)
|
590 |
+
|
591 |
+
if encoder_hidden_states is None:
|
592 |
+
encoder_hidden_states = hidden_states
|
593 |
+
elif attn.norm_cross:
|
594 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
595 |
+
|
596 |
+
key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
|
597 |
+
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
|
598 |
+
|
599 |
+
key = attn.head_to_batch_dim(key)
|
600 |
+
value = attn.head_to_batch_dim(value)
|
601 |
+
|
602 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
603 |
+
hidden_states = torch.bmm(attention_probs, value)
|
604 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
605 |
+
|
606 |
+
# linear proj
|
607 |
+
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
|
608 |
+
# dropout
|
609 |
+
hidden_states = attn.to_out[1](hidden_states)
|
610 |
+
|
611 |
+
if input_ndim == 4:
|
612 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
613 |
+
|
614 |
+
if attn.residual_connection:
|
615 |
+
hidden_states = hidden_states + residual
|
616 |
+
|
617 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
618 |
+
|
619 |
+
return hidden_states
|
620 |
+
|
621 |
+
|
622 |
+
class CustomDiffusionAttnProcessor(nn.Module):
|
623 |
+
r"""
|
624 |
+
Processor for implementing attention for the Custom Diffusion method.
|
625 |
+
|
626 |
+
Args:
|
627 |
+
train_kv (`bool`, defaults to `True`):
|
628 |
+
Whether to newly train the key and value matrices corresponding to the text features.
|
629 |
+
train_q_out (`bool`, defaults to `True`):
|
630 |
+
Whether to newly train query matrices corresponding to the latent image features.
|
631 |
+
hidden_size (`int`, *optional*, defaults to `None`):
|
632 |
+
The hidden size of the attention layer.
|
633 |
+
cross_attention_dim (`int`, *optional*, defaults to `None`):
|
634 |
+
The number of channels in the `encoder_hidden_states`.
|
635 |
+
out_bias (`bool`, defaults to `True`):
|
636 |
+
Whether to include the bias parameter in `train_q_out`.
|
637 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
638 |
+
The dropout probability to use.
|
639 |
+
"""
|
640 |
+
|
641 |
+
def __init__(
|
642 |
+
self,
|
643 |
+
train_kv=True,
|
644 |
+
train_q_out=True,
|
645 |
+
hidden_size=None,
|
646 |
+
cross_attention_dim=None,
|
647 |
+
out_bias=True,
|
648 |
+
dropout=0.0,
|
649 |
+
):
|
650 |
+
super().__init__()
|
651 |
+
self.train_kv = train_kv
|
652 |
+
self.train_q_out = train_q_out
|
653 |
+
|
654 |
+
self.hidden_size = hidden_size
|
655 |
+
self.cross_attention_dim = cross_attention_dim
|
656 |
+
|
657 |
+
# `_custom_diffusion` id for easy serialization and loading.
|
658 |
+
if self.train_kv:
|
659 |
+
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
660 |
+
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
661 |
+
if self.train_q_out:
|
662 |
+
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
|
663 |
+
self.to_out_custom_diffusion = nn.ModuleList([])
|
664 |
+
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
|
665 |
+
self.to_out_custom_diffusion.append(nn.Dropout(dropout))
|
666 |
+
|
667 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
668 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
669 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
670 |
+
if self.train_q_out:
|
671 |
+
query = self.to_q_custom_diffusion(hidden_states)
|
672 |
+
else:
|
673 |
+
query = attn.to_q(hidden_states)
|
674 |
+
|
675 |
+
if encoder_hidden_states is None:
|
676 |
+
crossattn = False
|
677 |
+
encoder_hidden_states = hidden_states
|
678 |
+
else:
|
679 |
+
crossattn = True
|
680 |
+
if attn.norm_cross:
|
681 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
682 |
+
|
683 |
+
if self.train_kv:
|
684 |
+
key = self.to_k_custom_diffusion(encoder_hidden_states)
|
685 |
+
value = self.to_v_custom_diffusion(encoder_hidden_states)
|
686 |
+
else:
|
687 |
+
key = attn.to_k(encoder_hidden_states)
|
688 |
+
value = attn.to_v(encoder_hidden_states)
|
689 |
+
|
690 |
+
if crossattn:
|
691 |
+
detach = torch.ones_like(key)
|
692 |
+
detach[:, :1, :] = detach[:, :1, :] * 0.0
|
693 |
+
key = detach * key + (1 - detach) * key.detach()
|
694 |
+
value = detach * value + (1 - detach) * value.detach()
|
695 |
+
|
696 |
+
query = attn.head_to_batch_dim(query)
|
697 |
+
key = attn.head_to_batch_dim(key)
|
698 |
+
value = attn.head_to_batch_dim(value)
|
699 |
+
|
700 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
701 |
+
hidden_states = torch.bmm(attention_probs, value)
|
702 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
703 |
+
|
704 |
+
if self.train_q_out:
|
705 |
+
# linear proj
|
706 |
+
hidden_states = self.to_out_custom_diffusion[0](hidden_states)
|
707 |
+
# dropout
|
708 |
+
hidden_states = self.to_out_custom_diffusion[1](hidden_states)
|
709 |
+
else:
|
710 |
+
# linear proj
|
711 |
+
hidden_states = attn.to_out[0](hidden_states)
|
712 |
+
# dropout
|
713 |
+
hidden_states = attn.to_out[1](hidden_states)
|
714 |
+
|
715 |
+
return hidden_states
|
716 |
+
|
717 |
+
|
718 |
+
class AttnAddedKVProcessor:
|
719 |
+
r"""
|
720 |
+
Processor for performing attention-related computations with extra learnable key and value matrices for the text
|
721 |
+
encoder.
|
722 |
+
"""
|
723 |
+
|
724 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
725 |
+
residual = hidden_states
|
726 |
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
727 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
728 |
+
|
729 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
730 |
+
|
731 |
+
if encoder_hidden_states is None:
|
732 |
+
encoder_hidden_states = hidden_states
|
733 |
+
elif attn.norm_cross:
|
734 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
735 |
+
|
736 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
737 |
+
|
738 |
+
query = attn.to_q(hidden_states)
|
739 |
+
query = attn.head_to_batch_dim(query)
|
740 |
+
|
741 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
742 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
743 |
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
744 |
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
745 |
+
|
746 |
+
if not attn.only_cross_attention:
|
747 |
+
key = attn.to_k(hidden_states)
|
748 |
+
value = attn.to_v(hidden_states)
|
749 |
+
key = attn.head_to_batch_dim(key)
|
750 |
+
value = attn.head_to_batch_dim(value)
|
751 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
752 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
753 |
+
else:
|
754 |
+
key = encoder_hidden_states_key_proj
|
755 |
+
value = encoder_hidden_states_value_proj
|
756 |
+
|
757 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
758 |
+
hidden_states = torch.bmm(attention_probs, value)
|
759 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
760 |
+
|
761 |
+
# linear proj
|
762 |
+
hidden_states = attn.to_out[0](hidden_states)
|
763 |
+
# dropout
|
764 |
+
hidden_states = attn.to_out[1](hidden_states)
|
765 |
+
|
766 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
767 |
+
hidden_states = hidden_states + residual
|
768 |
+
|
769 |
+
return hidden_states
|
770 |
+
|
771 |
+
|
772 |
+
class AttnAddedKVProcessor2_0:
|
773 |
+
r"""
|
774 |
+
Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
|
775 |
+
learnable key and value matrices for the text encoder.
|
776 |
+
"""
|
777 |
+
|
778 |
+
def __init__(self):
|
779 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
780 |
+
raise ImportError(
|
781 |
+
"AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
782 |
+
)
|
783 |
+
|
784 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
785 |
+
residual = hidden_states
|
786 |
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
787 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
788 |
+
|
789 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
|
790 |
+
|
791 |
+
if encoder_hidden_states is None:
|
792 |
+
encoder_hidden_states = hidden_states
|
793 |
+
elif attn.norm_cross:
|
794 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
795 |
+
|
796 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
797 |
+
|
798 |
+
query = attn.to_q(hidden_states)
|
799 |
+
query = attn.head_to_batch_dim(query, out_dim=4)
|
800 |
+
|
801 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
802 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
803 |
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
|
804 |
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
|
805 |
+
|
806 |
+
if not attn.only_cross_attention:
|
807 |
+
key = attn.to_k(hidden_states)
|
808 |
+
value = attn.to_v(hidden_states)
|
809 |
+
key = attn.head_to_batch_dim(key, out_dim=4)
|
810 |
+
value = attn.head_to_batch_dim(value, out_dim=4)
|
811 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
812 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
813 |
+
else:
|
814 |
+
key = encoder_hidden_states_key_proj
|
815 |
+
value = encoder_hidden_states_value_proj
|
816 |
+
|
817 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
818 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
819 |
+
hidden_states = F.scaled_dot_product_attention(
|
820 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
821 |
+
)
|
822 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
|
823 |
+
|
824 |
+
# linear proj
|
825 |
+
hidden_states = attn.to_out[0](hidden_states)
|
826 |
+
# dropout
|
827 |
+
hidden_states = attn.to_out[1](hidden_states)
|
828 |
+
|
829 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
830 |
+
hidden_states = hidden_states + residual
|
831 |
+
|
832 |
+
return hidden_states
|
833 |
+
|
834 |
+
|
835 |
+
class LoRAAttnAddedKVProcessor(nn.Module):
|
836 |
+
r"""
|
837 |
+
Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
|
838 |
+
encoder.
|
839 |
+
|
840 |
+
Args:
|
841 |
+
hidden_size (`int`, *optional*):
|
842 |
+
The hidden size of the attention layer.
|
843 |
+
cross_attention_dim (`int`, *optional*, defaults to `None`):
|
844 |
+
The number of channels in the `encoder_hidden_states`.
|
845 |
+
rank (`int`, defaults to 4):
|
846 |
+
The dimension of the LoRA update matrices.
|
847 |
+
|
848 |
+
"""
|
849 |
+
|
850 |
+
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
|
851 |
+
super().__init__()
|
852 |
+
|
853 |
+
self.hidden_size = hidden_size
|
854 |
+
self.cross_attention_dim = cross_attention_dim
|
855 |
+
self.rank = rank
|
856 |
+
|
857 |
+
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
858 |
+
self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
859 |
+
self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
860 |
+
self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
861 |
+
self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
862 |
+
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
863 |
+
|
864 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
|
865 |
+
residual = hidden_states
|
866 |
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
867 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
868 |
+
|
869 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
870 |
+
|
871 |
+
if encoder_hidden_states is None:
|
872 |
+
encoder_hidden_states = hidden_states
|
873 |
+
elif attn.norm_cross:
|
874 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
875 |
+
|
876 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
877 |
+
|
878 |
+
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
|
879 |
+
query = attn.head_to_batch_dim(query)
|
880 |
+
|
881 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + scale * self.add_k_proj_lora(
|
882 |
+
encoder_hidden_states
|
883 |
+
)
|
884 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + scale * self.add_v_proj_lora(
|
885 |
+
encoder_hidden_states
|
886 |
+
)
|
887 |
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
888 |
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
889 |
+
|
890 |
+
if not attn.only_cross_attention:
|
891 |
+
key = attn.to_k(hidden_states) + scale * self.to_k_lora(hidden_states)
|
892 |
+
value = attn.to_v(hidden_states) + scale * self.to_v_lora(hidden_states)
|
893 |
+
key = attn.head_to_batch_dim(key)
|
894 |
+
value = attn.head_to_batch_dim(value)
|
895 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
896 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
897 |
+
else:
|
898 |
+
key = encoder_hidden_states_key_proj
|
899 |
+
value = encoder_hidden_states_value_proj
|
900 |
+
|
901 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
902 |
+
hidden_states = torch.bmm(attention_probs, value)
|
903 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
904 |
+
|
905 |
+
# linear proj
|
906 |
+
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
|
907 |
+
# dropout
|
908 |
+
hidden_states = attn.to_out[1](hidden_states)
|
909 |
+
|
910 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
911 |
+
hidden_states = hidden_states + residual
|
912 |
+
|
913 |
+
return hidden_states
|
914 |
+
|
915 |
+
|
916 |
+
class XFormersAttnAddedKVProcessor:
|
917 |
+
r"""
|
918 |
+
Processor for implementing memory efficient attention using xFormers.
|
919 |
+
|
920 |
+
Args:
|
921 |
+
attention_op (`Callable`, *optional*, defaults to `None`):
|
922 |
+
The base
|
923 |
+
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
|
924 |
+
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
|
925 |
+
operator.
|
926 |
+
"""
|
927 |
+
|
928 |
+
def __init__(self, attention_op: Optional[Callable] = None):
|
929 |
+
self.attention_op = attention_op
|
930 |
+
|
931 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
932 |
+
residual = hidden_states
|
933 |
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
934 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
935 |
+
|
936 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
937 |
+
|
938 |
+
if encoder_hidden_states is None:
|
939 |
+
encoder_hidden_states = hidden_states
|
940 |
+
elif attn.norm_cross:
|
941 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
942 |
+
|
943 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
944 |
+
|
945 |
+
query = attn.to_q(hidden_states)
|
946 |
+
query = attn.head_to_batch_dim(query)
|
947 |
+
|
948 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
949 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
950 |
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
951 |
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
952 |
+
|
953 |
+
if not attn.only_cross_attention:
|
954 |
+
key = attn.to_k(hidden_states)
|
955 |
+
value = attn.to_v(hidden_states)
|
956 |
+
key = attn.head_to_batch_dim(key)
|
957 |
+
value = attn.head_to_batch_dim(value)
|
958 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
959 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
960 |
+
else:
|
961 |
+
key = encoder_hidden_states_key_proj
|
962 |
+
value = encoder_hidden_states_value_proj
|
963 |
+
|
964 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
965 |
+
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
966 |
+
)
|
967 |
+
hidden_states = hidden_states.to(query.dtype)
|
968 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
969 |
+
|
970 |
+
# linear proj
|
971 |
+
hidden_states = attn.to_out[0](hidden_states)
|
972 |
+
# dropout
|
973 |
+
hidden_states = attn.to_out[1](hidden_states)
|
974 |
+
|
975 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
976 |
+
hidden_states = hidden_states + residual
|
977 |
+
|
978 |
+
return hidden_states
|
979 |
+
|
980 |
+
|
981 |
+
class XFormersAttnProcessor:
|
982 |
+
r"""
|
983 |
+
Processor for implementing memory efficient attention using xFormers.
|
984 |
+
|
985 |
+
Args:
|
986 |
+
attention_op (`Callable`, *optional*, defaults to `None`):
|
987 |
+
The base
|
988 |
+
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
|
989 |
+
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
|
990 |
+
operator.
|
991 |
+
"""
|
992 |
+
|
993 |
+
def __init__(self, attention_op: Optional[Callable] = None):
|
994 |
+
self.attention_op = attention_op
|
995 |
+
|
996 |
+
def __call__(
|
997 |
+
self,
|
998 |
+
attn: Attention,
|
999 |
+
hidden_states: torch.FloatTensor,
|
1000 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
1001 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1002 |
+
temb: Optional[torch.FloatTensor] = None,
|
1003 |
+
):
|
1004 |
+
residual = hidden_states
|
1005 |
+
|
1006 |
+
if attn.spatial_norm is not None:
|
1007 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
1008 |
+
|
1009 |
+
input_ndim = hidden_states.ndim
|
1010 |
+
|
1011 |
+
if input_ndim == 4:
|
1012 |
+
batch_size, channel, height, width = hidden_states.shape
|
1013 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1014 |
+
|
1015 |
+
batch_size, key_tokens, _ = (
|
1016 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1017 |
+
)
|
1018 |
+
|
1019 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
|
1020 |
+
if attention_mask is not None:
|
1021 |
+
# expand our mask's singleton query_tokens dimension:
|
1022 |
+
# [batch*heads, 1, key_tokens] ->
|
1023 |
+
# [batch*heads, query_tokens, key_tokens]
|
1024 |
+
# so that it can be added as a bias onto the attention scores that xformers computes:
|
1025 |
+
# [batch*heads, query_tokens, key_tokens]
|
1026 |
+
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
|
1027 |
+
_, query_tokens, _ = hidden_states.shape
|
1028 |
+
attention_mask = attention_mask.expand(-1, query_tokens, -1)
|
1029 |
+
|
1030 |
+
if attn.group_norm is not None:
|
1031 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1032 |
+
|
1033 |
+
query = attn.to_q(hidden_states)
|
1034 |
+
|
1035 |
+
if encoder_hidden_states is None:
|
1036 |
+
encoder_hidden_states = hidden_states
|
1037 |
+
elif attn.norm_cross:
|
1038 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1039 |
+
|
1040 |
+
key = attn.to_k(encoder_hidden_states)
|
1041 |
+
value = attn.to_v(encoder_hidden_states)
|
1042 |
+
|
1043 |
+
query = attn.head_to_batch_dim(query).contiguous()
|
1044 |
+
key = attn.head_to_batch_dim(key).contiguous()
|
1045 |
+
value = attn.head_to_batch_dim(value).contiguous()
|
1046 |
+
|
1047 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
1048 |
+
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
1049 |
+
)
|
1050 |
+
hidden_states = hidden_states.to(query.dtype)
|
1051 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1052 |
+
|
1053 |
+
# linear proj
|
1054 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1055 |
+
# dropout
|
1056 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1057 |
+
|
1058 |
+
if input_ndim == 4:
|
1059 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1060 |
+
|
1061 |
+
if attn.residual_connection:
|
1062 |
+
hidden_states = hidden_states + residual
|
1063 |
+
|
1064 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1065 |
+
|
1066 |
+
return hidden_states
|
1067 |
+
|
1068 |
+
|
1069 |
+
class AttnProcessor2_0:
|
1070 |
+
r"""
|
1071 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
1072 |
+
"""
|
1073 |
+
|
1074 |
+
def __init__(self):
|
1075 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
1076 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
1077 |
+
|
1078 |
+
def __call__(
|
1079 |
+
self,
|
1080 |
+
attn: Attention,
|
1081 |
+
hidden_states,
|
1082 |
+
encoder_hidden_states=None,
|
1083 |
+
attention_mask=None,
|
1084 |
+
temb=None,
|
1085 |
+
):
|
1086 |
+
residual = hidden_states
|
1087 |
+
|
1088 |
+
if attn.spatial_norm is not None:
|
1089 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
1090 |
+
|
1091 |
+
input_ndim = hidden_states.ndim
|
1092 |
+
|
1093 |
+
if input_ndim == 4:
|
1094 |
+
batch_size, channel, height, width = hidden_states.shape
|
1095 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1096 |
+
|
1097 |
+
batch_size, sequence_length, _ = (
|
1098 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1099 |
+
)
|
1100 |
+
inner_dim = hidden_states.shape[-1]
|
1101 |
+
|
1102 |
+
if attention_mask is not None:
|
1103 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1104 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
1105 |
+
# (batch, heads, source_length, target_length)
|
1106 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
1107 |
+
|
1108 |
+
if attn.group_norm is not None:
|
1109 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1110 |
+
|
1111 |
+
query = attn.to_q(hidden_states)
|
1112 |
+
|
1113 |
+
if encoder_hidden_states is None:
|
1114 |
+
encoder_hidden_states = hidden_states
|
1115 |
+
elif attn.norm_cross:
|
1116 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1117 |
+
|
1118 |
+
key = attn.to_k(encoder_hidden_states)
|
1119 |
+
value = attn.to_v(encoder_hidden_states)
|
1120 |
+
|
1121 |
+
head_dim = inner_dim // attn.heads
|
1122 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1123 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1124 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1125 |
+
|
1126 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
1127 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
1128 |
+
hidden_states = F.scaled_dot_product_attention(
|
1129 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
1130 |
+
)
|
1131 |
+
|
1132 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
1133 |
+
hidden_states = hidden_states.to(query.dtype)
|
1134 |
+
|
1135 |
+
# linear proj
|
1136 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1137 |
+
# dropout
|
1138 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1139 |
+
|
1140 |
+
if input_ndim == 4:
|
1141 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1142 |
+
|
1143 |
+
if attn.residual_connection:
|
1144 |
+
hidden_states = hidden_states + residual
|
1145 |
+
|
1146 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1147 |
+
|
1148 |
+
return hidden_states
|
1149 |
+
|
1150 |
+
|
1151 |
+
class LoRAXFormersAttnProcessor(nn.Module):
|
1152 |
+
r"""
|
1153 |
+
Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.
|
1154 |
+
|
1155 |
+
Args:
|
1156 |
+
hidden_size (`int`, *optional*):
|
1157 |
+
The hidden size of the attention layer.
|
1158 |
+
cross_attention_dim (`int`, *optional*):
|
1159 |
+
The number of channels in the `encoder_hidden_states`.
|
1160 |
+
rank (`int`, defaults to 4):
|
1161 |
+
The dimension of the LoRA update matrices.
|
1162 |
+
attention_op (`Callable`, *optional*, defaults to `None`):
|
1163 |
+
The base
|
1164 |
+
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
|
1165 |
+
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
|
1166 |
+
operator.
|
1167 |
+
network_alpha (`int`, *optional*):
|
1168 |
+
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
|
1169 |
+
|
1170 |
+
"""
|
1171 |
+
|
1172 |
+
def __init__(
|
1173 |
+
self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None
|
1174 |
+
):
|
1175 |
+
super().__init__()
|
1176 |
+
|
1177 |
+
self.hidden_size = hidden_size
|
1178 |
+
self.cross_attention_dim = cross_attention_dim
|
1179 |
+
self.rank = rank
|
1180 |
+
self.attention_op = attention_op
|
1181 |
+
|
1182 |
+
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
1183 |
+
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
1184 |
+
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
1185 |
+
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
1186 |
+
|
1187 |
+
def __call__(
|
1188 |
+
self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
|
1189 |
+
):
|
1190 |
+
residual = hidden_states
|
1191 |
+
|
1192 |
+
if attn.spatial_norm is not None:
|
1193 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
1194 |
+
|
1195 |
+
input_ndim = hidden_states.ndim
|
1196 |
+
|
1197 |
+
if input_ndim == 4:
|
1198 |
+
batch_size, channel, height, width = hidden_states.shape
|
1199 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1200 |
+
|
1201 |
+
batch_size, sequence_length, _ = (
|
1202 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1203 |
+
)
|
1204 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1205 |
+
|
1206 |
+
if attn.group_norm is not None:
|
1207 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1208 |
+
|
1209 |
+
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
|
1210 |
+
query = attn.head_to_batch_dim(query).contiguous()
|
1211 |
+
|
1212 |
+
if encoder_hidden_states is None:
|
1213 |
+
encoder_hidden_states = hidden_states
|
1214 |
+
elif attn.norm_cross:
|
1215 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1216 |
+
|
1217 |
+
key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
|
1218 |
+
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
|
1219 |
+
|
1220 |
+
key = attn.head_to_batch_dim(key).contiguous()
|
1221 |
+
value = attn.head_to_batch_dim(value).contiguous()
|
1222 |
+
|
1223 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
1224 |
+
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
1225 |
+
)
|
1226 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1227 |
+
|
1228 |
+
# linear proj
|
1229 |
+
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
|
1230 |
+
# dropout
|
1231 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1232 |
+
|
1233 |
+
if input_ndim == 4:
|
1234 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1235 |
+
|
1236 |
+
if attn.residual_connection:
|
1237 |
+
hidden_states = hidden_states + residual
|
1238 |
+
|
1239 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1240 |
+
|
1241 |
+
return hidden_states
|
1242 |
+
|
1243 |
+
|
1244 |
+
class LoRAAttnProcessor2_0(nn.Module):
|
1245 |
+
r"""
|
1246 |
+
Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product
|
1247 |
+
attention.
|
1248 |
+
|
1249 |
+
Args:
|
1250 |
+
hidden_size (`int`):
|
1251 |
+
The hidden size of the attention layer.
|
1252 |
+
cross_attention_dim (`int`, *optional*):
|
1253 |
+
The number of channels in the `encoder_hidden_states`.
|
1254 |
+
rank (`int`, defaults to 4):
|
1255 |
+
The dimension of the LoRA update matrices.
|
1256 |
+
network_alpha (`int`, *optional*):
|
1257 |
+
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
|
1258 |
+
"""
|
1259 |
+
|
1260 |
+
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
|
1261 |
+
super().__init__()
|
1262 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
1263 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
1264 |
+
|
1265 |
+
self.hidden_size = hidden_size
|
1266 |
+
self.cross_attention_dim = cross_attention_dim
|
1267 |
+
self.rank = rank
|
1268 |
+
|
1269 |
+
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
1270 |
+
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
1271 |
+
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
1272 |
+
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
1273 |
+
|
1274 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
|
1275 |
+
residual = hidden_states
|
1276 |
+
|
1277 |
+
input_ndim = hidden_states.ndim
|
1278 |
+
|
1279 |
+
if input_ndim == 4:
|
1280 |
+
batch_size, channel, height, width = hidden_states.shape
|
1281 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1282 |
+
|
1283 |
+
batch_size, sequence_length, _ = (
|
1284 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1285 |
+
)
|
1286 |
+
inner_dim = hidden_states.shape[-1]
|
1287 |
+
|
1288 |
+
if attention_mask is not None:
|
1289 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1290 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
1291 |
+
# (batch, heads, source_length, target_length)
|
1292 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
1293 |
+
|
1294 |
+
if attn.group_norm is not None:
|
1295 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1296 |
+
|
1297 |
+
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
|
1298 |
+
|
1299 |
+
if encoder_hidden_states is None:
|
1300 |
+
encoder_hidden_states = hidden_states
|
1301 |
+
elif attn.norm_cross:
|
1302 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1303 |
+
|
1304 |
+
key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
|
1305 |
+
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
|
1306 |
+
|
1307 |
+
head_dim = inner_dim // attn.heads
|
1308 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1309 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1310 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
1311 |
+
|
1312 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
1313 |
+
hidden_states = F.scaled_dot_product_attention(
|
1314 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
1315 |
+
)
|
1316 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
1317 |
+
hidden_states = hidden_states.to(query.dtype)
|
1318 |
+
|
1319 |
+
# linear proj
|
1320 |
+
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
|
1321 |
+
# dropout
|
1322 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1323 |
+
|
1324 |
+
if input_ndim == 4:
|
1325 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1326 |
+
|
1327 |
+
if attn.residual_connection:
|
1328 |
+
hidden_states = hidden_states + residual
|
1329 |
+
|
1330 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1331 |
+
|
1332 |
+
return hidden_states
|
1333 |
+
|
1334 |
+
|
1335 |
+
class CustomDiffusionXFormersAttnProcessor(nn.Module):
|
1336 |
+
r"""
|
1337 |
+
Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
|
1338 |
+
|
1339 |
+
Args:
|
1340 |
+
train_kv (`bool`, defaults to `True`):
|
1341 |
+
Whether to newly train the key and value matrices corresponding to the text features.
|
1342 |
+
train_q_out (`bool`, defaults to `True`):
|
1343 |
+
Whether to newly train query matrices corresponding to the latent image features.
|
1344 |
+
hidden_size (`int`, *optional*, defaults to `None`):
|
1345 |
+
The hidden size of the attention layer.
|
1346 |
+
cross_attention_dim (`int`, *optional*, defaults to `None`):
|
1347 |
+
The number of channels in the `encoder_hidden_states`.
|
1348 |
+
out_bias (`bool`, defaults to `True`):
|
1349 |
+
Whether to include the bias parameter in `train_q_out`.
|
1350 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
1351 |
+
The dropout probability to use.
|
1352 |
+
attention_op (`Callable`, *optional*, defaults to `None`):
|
1353 |
+
The base
|
1354 |
+
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
|
1355 |
+
as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator.
|
1356 |
+
"""
|
1357 |
+
|
1358 |
+
def __init__(
|
1359 |
+
self,
|
1360 |
+
train_kv=True,
|
1361 |
+
train_q_out=False,
|
1362 |
+
hidden_size=None,
|
1363 |
+
cross_attention_dim=None,
|
1364 |
+
out_bias=True,
|
1365 |
+
dropout=0.0,
|
1366 |
+
attention_op: Optional[Callable] = None,
|
1367 |
+
):
|
1368 |
+
super().__init__()
|
1369 |
+
self.train_kv = train_kv
|
1370 |
+
self.train_q_out = train_q_out
|
1371 |
+
|
1372 |
+
self.hidden_size = hidden_size
|
1373 |
+
self.cross_attention_dim = cross_attention_dim
|
1374 |
+
self.attention_op = attention_op
|
1375 |
+
|
1376 |
+
# `_custom_diffusion` id for easy serialization and loading.
|
1377 |
+
if self.train_kv:
|
1378 |
+
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
1379 |
+
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
1380 |
+
if self.train_q_out:
|
1381 |
+
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
|
1382 |
+
self.to_out_custom_diffusion = nn.ModuleList([])
|
1383 |
+
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
|
1384 |
+
self.to_out_custom_diffusion.append(nn.Dropout(dropout))
|
1385 |
+
|
1386 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
1387 |
+
batch_size, sequence_length, _ = (
|
1388 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1389 |
+
)
|
1390 |
+
|
1391 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1392 |
+
|
1393 |
+
if self.train_q_out:
|
1394 |
+
query = self.to_q_custom_diffusion(hidden_states)
|
1395 |
+
else:
|
1396 |
+
query = attn.to_q(hidden_states)
|
1397 |
+
|
1398 |
+
if encoder_hidden_states is None:
|
1399 |
+
crossattn = False
|
1400 |
+
encoder_hidden_states = hidden_states
|
1401 |
+
else:
|
1402 |
+
crossattn = True
|
1403 |
+
if attn.norm_cross:
|
1404 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1405 |
+
|
1406 |
+
if self.train_kv:
|
1407 |
+
key = self.to_k_custom_diffusion(encoder_hidden_states)
|
1408 |
+
value = self.to_v_custom_diffusion(encoder_hidden_states)
|
1409 |
+
else:
|
1410 |
+
key = attn.to_k(encoder_hidden_states)
|
1411 |
+
value = attn.to_v(encoder_hidden_states)
|
1412 |
+
|
1413 |
+
if crossattn:
|
1414 |
+
detach = torch.ones_like(key)
|
1415 |
+
detach[:, :1, :] = detach[:, :1, :] * 0.0
|
1416 |
+
key = detach * key + (1 - detach) * key.detach()
|
1417 |
+
value = detach * value + (1 - detach) * value.detach()
|
1418 |
+
|
1419 |
+
query = attn.head_to_batch_dim(query).contiguous()
|
1420 |
+
key = attn.head_to_batch_dim(key).contiguous()
|
1421 |
+
value = attn.head_to_batch_dim(value).contiguous()
|
1422 |
+
|
1423 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
1424 |
+
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
1425 |
+
)
|
1426 |
+
hidden_states = hidden_states.to(query.dtype)
|
1427 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1428 |
+
|
1429 |
+
if self.train_q_out:
|
1430 |
+
# linear proj
|
1431 |
+
hidden_states = self.to_out_custom_diffusion[0](hidden_states)
|
1432 |
+
# dropout
|
1433 |
+
hidden_states = self.to_out_custom_diffusion[1](hidden_states)
|
1434 |
+
else:
|
1435 |
+
# linear proj
|
1436 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1437 |
+
# dropout
|
1438 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1439 |
+
return hidden_states
|
1440 |
+
|
1441 |
+
|
1442 |
+
class SlicedAttnProcessor:
|
1443 |
+
r"""
|
1444 |
+
Processor for implementing sliced attention.
|
1445 |
+
|
1446 |
+
Args:
|
1447 |
+
slice_size (`int`, *optional*):
|
1448 |
+
The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
|
1449 |
+
`attention_head_dim` must be a multiple of the `slice_size`.
|
1450 |
+
"""
|
1451 |
+
|
1452 |
+
def __init__(self, slice_size):
|
1453 |
+
self.slice_size = slice_size
|
1454 |
+
|
1455 |
+
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
1456 |
+
residual = hidden_states
|
1457 |
+
|
1458 |
+
input_ndim = hidden_states.ndim
|
1459 |
+
|
1460 |
+
if input_ndim == 4:
|
1461 |
+
batch_size, channel, height, width = hidden_states.shape
|
1462 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
1463 |
+
|
1464 |
+
batch_size, sequence_length, _ = (
|
1465 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
1466 |
+
)
|
1467 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1468 |
+
|
1469 |
+
if attn.group_norm is not None:
|
1470 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1471 |
+
|
1472 |
+
query = attn.to_q(hidden_states)
|
1473 |
+
dim = query.shape[-1]
|
1474 |
+
query = attn.head_to_batch_dim(query)
|
1475 |
+
|
1476 |
+
if encoder_hidden_states is None:
|
1477 |
+
encoder_hidden_states = hidden_states
|
1478 |
+
elif attn.norm_cross:
|
1479 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1480 |
+
|
1481 |
+
key = attn.to_k(encoder_hidden_states)
|
1482 |
+
value = attn.to_v(encoder_hidden_states)
|
1483 |
+
key = attn.head_to_batch_dim(key)
|
1484 |
+
value = attn.head_to_batch_dim(value)
|
1485 |
+
|
1486 |
+
batch_size_attention, query_tokens, _ = query.shape
|
1487 |
+
hidden_states = torch.zeros(
|
1488 |
+
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
|
1489 |
+
)
|
1490 |
+
|
1491 |
+
for i in range(batch_size_attention // self.slice_size):
|
1492 |
+
start_idx = i * self.slice_size
|
1493 |
+
end_idx = (i + 1) * self.slice_size
|
1494 |
+
|
1495 |
+
query_slice = query[start_idx:end_idx]
|
1496 |
+
key_slice = key[start_idx:end_idx]
|
1497 |
+
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
1498 |
+
|
1499 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
1500 |
+
|
1501 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
1502 |
+
|
1503 |
+
hidden_states[start_idx:end_idx] = attn_slice
|
1504 |
+
|
1505 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1506 |
+
|
1507 |
+
# linear proj
|
1508 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1509 |
+
# dropout
|
1510 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1511 |
+
|
1512 |
+
if input_ndim == 4:
|
1513 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
1514 |
+
|
1515 |
+
if attn.residual_connection:
|
1516 |
+
hidden_states = hidden_states + residual
|
1517 |
+
|
1518 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
1519 |
+
|
1520 |
+
return hidden_states
|
1521 |
+
|
1522 |
+
|
1523 |
+
class SlicedAttnAddedKVProcessor:
|
1524 |
+
r"""
|
1525 |
+
Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.
|
1526 |
+
|
1527 |
+
Args:
|
1528 |
+
slice_size (`int`, *optional*):
|
1529 |
+
The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
|
1530 |
+
`attention_head_dim` must be a multiple of the `slice_size`.
|
1531 |
+
"""
|
1532 |
+
|
1533 |
+
def __init__(self, slice_size):
|
1534 |
+
self.slice_size = slice_size
|
1535 |
+
|
1536 |
+
def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
|
1537 |
+
residual = hidden_states
|
1538 |
+
|
1539 |
+
if attn.spatial_norm is not None:
|
1540 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
1541 |
+
|
1542 |
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
1543 |
+
|
1544 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
1545 |
+
|
1546 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
1547 |
+
|
1548 |
+
if encoder_hidden_states is None:
|
1549 |
+
encoder_hidden_states = hidden_states
|
1550 |
+
elif attn.norm_cross:
|
1551 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
1552 |
+
|
1553 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
1554 |
+
|
1555 |
+
query = attn.to_q(hidden_states)
|
1556 |
+
dim = query.shape[-1]
|
1557 |
+
query = attn.head_to_batch_dim(query)
|
1558 |
+
|
1559 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
1560 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
1561 |
+
|
1562 |
+
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
|
1563 |
+
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
|
1564 |
+
|
1565 |
+
if not attn.only_cross_attention:
|
1566 |
+
key = attn.to_k(hidden_states)
|
1567 |
+
value = attn.to_v(hidden_states)
|
1568 |
+
key = attn.head_to_batch_dim(key)
|
1569 |
+
value = attn.head_to_batch_dim(value)
|
1570 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
1571 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
1572 |
+
else:
|
1573 |
+
key = encoder_hidden_states_key_proj
|
1574 |
+
value = encoder_hidden_states_value_proj
|
1575 |
+
|
1576 |
+
batch_size_attention, query_tokens, _ = query.shape
|
1577 |
+
hidden_states = torch.zeros(
|
1578 |
+
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
|
1579 |
+
)
|
1580 |
+
|
1581 |
+
for i in range(batch_size_attention // self.slice_size):
|
1582 |
+
start_idx = i * self.slice_size
|
1583 |
+
end_idx = (i + 1) * self.slice_size
|
1584 |
+
|
1585 |
+
query_slice = query[start_idx:end_idx]
|
1586 |
+
key_slice = key[start_idx:end_idx]
|
1587 |
+
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
1588 |
+
|
1589 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
1590 |
+
|
1591 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
1592 |
+
|
1593 |
+
hidden_states[start_idx:end_idx] = attn_slice
|
1594 |
+
|
1595 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
1596 |
+
|
1597 |
+
# linear proj
|
1598 |
+
hidden_states = attn.to_out[0](hidden_states)
|
1599 |
+
# dropout
|
1600 |
+
hidden_states = attn.to_out[1](hidden_states)
|
1601 |
+
|
1602 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
1603 |
+
hidden_states = hidden_states + residual
|
1604 |
+
|
1605 |
+
return hidden_states
|
1606 |
+
|
1607 |
+
|
1608 |
+
AttentionProcessor = Union[
|
1609 |
+
AttnProcessor,
|
1610 |
+
AttnProcessor2_0,
|
1611 |
+
XFormersAttnProcessor,
|
1612 |
+
SlicedAttnProcessor,
|
1613 |
+
AttnAddedKVProcessor,
|
1614 |
+
SlicedAttnAddedKVProcessor,
|
1615 |
+
AttnAddedKVProcessor2_0,
|
1616 |
+
XFormersAttnAddedKVProcessor,
|
1617 |
+
LoRAAttnProcessor,
|
1618 |
+
LoRAXFormersAttnProcessor,
|
1619 |
+
LoRAAttnProcessor2_0,
|
1620 |
+
LoRAAttnAddedKVProcessor,
|
1621 |
+
CustomDiffusionAttnProcessor,
|
1622 |
+
CustomDiffusionXFormersAttnProcessor,
|
1623 |
+
]
|
1624 |
+
|
1625 |
+
|
1626 |
+
class SpatialNorm(nn.Module):
|
1627 |
+
"""
|
1628 |
+
Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002
|
1629 |
+
"""
|
1630 |
+
|
1631 |
+
def __init__(
|
1632 |
+
self,
|
1633 |
+
f_channels,
|
1634 |
+
zq_channels,
|
1635 |
+
):
|
1636 |
+
super().__init__()
|
1637 |
+
self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
|
1638 |
+
self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
|
1639 |
+
self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
|
1640 |
+
|
1641 |
+
def forward(self, f, zq):
|
1642 |
+
f_size = f.shape[-2:]
|
1643 |
+
zq = F.interpolate(zq, size=f_size, mode="nearest")
|
1644 |
+
norm_f = self.norm_layer(f)
|
1645 |
+
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
|
1646 |
+
return new_f
|
diffusers/models/dual_transformer_2d.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Optional
|
15 |
+
|
16 |
+
from torch import nn
|
17 |
+
|
18 |
+
from .transformer_2d import Transformer2DModel, Transformer2DModelOutput
|
19 |
+
|
20 |
+
|
21 |
+
class DualTransformer2DModel(nn.Module):
|
22 |
+
"""
|
23 |
+
Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
|
24 |
+
|
25 |
+
Parameters:
|
26 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
27 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
28 |
+
in_channels (`int`, *optional*):
|
29 |
+
Pass if the input is continuous. The number of channels in the input and output.
|
30 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
31 |
+
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
|
32 |
+
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
|
33 |
+
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
|
34 |
+
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
|
35 |
+
`ImagePositionalEmbeddings`.
|
36 |
+
num_vector_embeds (`int`, *optional*):
|
37 |
+
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
|
38 |
+
Includes the class for the masked latent pixel.
|
39 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
40 |
+
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
|
41 |
+
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
|
42 |
+
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
|
43 |
+
up to but not more than steps than `num_embeds_ada_norm`.
|
44 |
+
attention_bias (`bool`, *optional*):
|
45 |
+
Configure if the TransformerBlocks' attention should contain a bias parameter.
|
46 |
+
"""
|
47 |
+
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
num_attention_heads: int = 16,
|
51 |
+
attention_head_dim: int = 88,
|
52 |
+
in_channels: Optional[int] = None,
|
53 |
+
num_layers: int = 1,
|
54 |
+
dropout: float = 0.0,
|
55 |
+
norm_num_groups: int = 32,
|
56 |
+
cross_attention_dim: Optional[int] = None,
|
57 |
+
attention_bias: bool = False,
|
58 |
+
sample_size: Optional[int] = None,
|
59 |
+
num_vector_embeds: Optional[int] = None,
|
60 |
+
activation_fn: str = "geglu",
|
61 |
+
num_embeds_ada_norm: Optional[int] = None,
|
62 |
+
):
|
63 |
+
super().__init__()
|
64 |
+
self.transformers = nn.ModuleList(
|
65 |
+
[
|
66 |
+
Transformer2DModel(
|
67 |
+
num_attention_heads=num_attention_heads,
|
68 |
+
attention_head_dim=attention_head_dim,
|
69 |
+
in_channels=in_channels,
|
70 |
+
num_layers=num_layers,
|
71 |
+
dropout=dropout,
|
72 |
+
norm_num_groups=norm_num_groups,
|
73 |
+
cross_attention_dim=cross_attention_dim,
|
74 |
+
attention_bias=attention_bias,
|
75 |
+
sample_size=sample_size,
|
76 |
+
num_vector_embeds=num_vector_embeds,
|
77 |
+
activation_fn=activation_fn,
|
78 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
79 |
+
)
|
80 |
+
for _ in range(2)
|
81 |
+
]
|
82 |
+
)
|
83 |
+
|
84 |
+
# Variables that can be set by a pipeline:
|
85 |
+
|
86 |
+
# The ratio of transformer1 to transformer2's output states to be combined during inference
|
87 |
+
self.mix_ratio = 0.5
|
88 |
+
|
89 |
+
# The shape of `encoder_hidden_states` is expected to be
|
90 |
+
# `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
|
91 |
+
self.condition_lengths = [77, 257]
|
92 |
+
|
93 |
+
# Which transformer to use to encode which condition.
|
94 |
+
# E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
|
95 |
+
self.transformer_index_for_condition = [1, 0]
|
96 |
+
|
97 |
+
def forward(
|
98 |
+
self,
|
99 |
+
hidden_states,
|
100 |
+
encoder_hidden_states,
|
101 |
+
timestep=None,
|
102 |
+
attention_mask=None,
|
103 |
+
cross_attention_kwargs=None,
|
104 |
+
return_dict: bool = True,
|
105 |
+
):
|
106 |
+
"""
|
107 |
+
Args:
|
108 |
+
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
|
109 |
+
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
|
110 |
+
hidden_states
|
111 |
+
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
112 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
113 |
+
self-attention.
|
114 |
+
timestep ( `torch.long`, *optional*):
|
115 |
+
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
|
116 |
+
attention_mask (`torch.FloatTensor`, *optional*):
|
117 |
+
Optional attention mask to be applied in Attention
|
118 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
119 |
+
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
120 |
+
|
121 |
+
Returns:
|
122 |
+
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
|
123 |
+
[`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
124 |
+
returning a tuple, the first element is the sample tensor.
|
125 |
+
"""
|
126 |
+
input_states = hidden_states
|
127 |
+
|
128 |
+
encoded_states = []
|
129 |
+
tokens_start = 0
|
130 |
+
# attention_mask is not used yet
|
131 |
+
for i in range(2):
|
132 |
+
# for each of the two transformers, pass the corresponding condition tokens
|
133 |
+
condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
|
134 |
+
transformer_index = self.transformer_index_for_condition[i]
|
135 |
+
encoded_state = self.transformers[transformer_index](
|
136 |
+
input_states,
|
137 |
+
encoder_hidden_states=condition_state,
|
138 |
+
timestep=timestep,
|
139 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
140 |
+
return_dict=False,
|
141 |
+
)[0]
|
142 |
+
encoded_states.append(encoded_state - input_states)
|
143 |
+
tokens_start += self.condition_lengths[i]
|
144 |
+
|
145 |
+
output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
|
146 |
+
output_states = output_states + input_states
|
147 |
+
|
148 |
+
if not return_dict:
|
149 |
+
return (output_states,)
|
150 |
+
|
151 |
+
return Transformer2DModelOutput(sample=output_states)
|
diffusers/models/embeddings.py
ADDED
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import math
|
16 |
+
from typing import Optional
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
from torch import nn
|
21 |
+
|
22 |
+
from .activations import get_activation
|
23 |
+
|
24 |
+
|
25 |
+
def get_timestep_embedding(
|
26 |
+
timesteps: torch.Tensor,
|
27 |
+
embedding_dim: int,
|
28 |
+
flip_sin_to_cos: bool = False,
|
29 |
+
downscale_freq_shift: float = 1,
|
30 |
+
scale: float = 1,
|
31 |
+
max_period: int = 10000,
|
32 |
+
):
|
33 |
+
"""
|
34 |
+
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
35 |
+
|
36 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
37 |
+
These may be fractional.
|
38 |
+
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
|
39 |
+
embeddings. :return: an [N x dim] Tensor of positional embeddings.
|
40 |
+
"""
|
41 |
+
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
42 |
+
|
43 |
+
half_dim = embedding_dim // 2
|
44 |
+
exponent = -math.log(max_period) * torch.arange(
|
45 |
+
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
46 |
+
)
|
47 |
+
exponent = exponent / (half_dim - downscale_freq_shift)
|
48 |
+
|
49 |
+
emb = torch.exp(exponent)
|
50 |
+
emb = timesteps[:, None].float() * emb[None, :]
|
51 |
+
|
52 |
+
# scale embeddings
|
53 |
+
emb = scale * emb
|
54 |
+
|
55 |
+
# concat sine and cosine embeddings
|
56 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
57 |
+
|
58 |
+
# flip sine and cosine embeddings
|
59 |
+
if flip_sin_to_cos:
|
60 |
+
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
61 |
+
|
62 |
+
# zero pad
|
63 |
+
if embedding_dim % 2 == 1:
|
64 |
+
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
65 |
+
return emb
|
66 |
+
|
67 |
+
|
68 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
|
69 |
+
"""
|
70 |
+
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
|
71 |
+
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
72 |
+
"""
|
73 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
74 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
75 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
76 |
+
grid = np.stack(grid, axis=0)
|
77 |
+
|
78 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
79 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
80 |
+
if cls_token and extra_tokens > 0:
|
81 |
+
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
82 |
+
return pos_embed
|
83 |
+
|
84 |
+
|
85 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
86 |
+
if embed_dim % 2 != 0:
|
87 |
+
raise ValueError("embed_dim must be divisible by 2")
|
88 |
+
|
89 |
+
# use half of dimensions to encode grid_h
|
90 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
91 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
92 |
+
|
93 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
94 |
+
return emb
|
95 |
+
|
96 |
+
|
97 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
98 |
+
"""
|
99 |
+
embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
|
100 |
+
"""
|
101 |
+
if embed_dim % 2 != 0:
|
102 |
+
raise ValueError("embed_dim must be divisible by 2")
|
103 |
+
|
104 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
105 |
+
omega /= embed_dim / 2.0
|
106 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
107 |
+
|
108 |
+
pos = pos.reshape(-1) # (M,)
|
109 |
+
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
110 |
+
|
111 |
+
emb_sin = np.sin(out) # (M, D/2)
|
112 |
+
emb_cos = np.cos(out) # (M, D/2)
|
113 |
+
|
114 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
115 |
+
return emb
|
116 |
+
|
117 |
+
|
118 |
+
class PatchEmbed(nn.Module):
|
119 |
+
"""2D Image to Patch Embedding"""
|
120 |
+
|
121 |
+
def __init__(
|
122 |
+
self,
|
123 |
+
height=224,
|
124 |
+
width=224,
|
125 |
+
patch_size=16,
|
126 |
+
in_channels=3,
|
127 |
+
embed_dim=768,
|
128 |
+
layer_norm=False,
|
129 |
+
flatten=True,
|
130 |
+
bias=True,
|
131 |
+
):
|
132 |
+
super().__init__()
|
133 |
+
|
134 |
+
num_patches = (height // patch_size) * (width // patch_size)
|
135 |
+
self.flatten = flatten
|
136 |
+
self.layer_norm = layer_norm
|
137 |
+
|
138 |
+
self.proj = nn.Conv2d(
|
139 |
+
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
|
140 |
+
)
|
141 |
+
if layer_norm:
|
142 |
+
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
|
143 |
+
else:
|
144 |
+
self.norm = None
|
145 |
+
|
146 |
+
pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
|
147 |
+
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
|
148 |
+
|
149 |
+
def forward(self, latent):
|
150 |
+
latent = self.proj(latent)
|
151 |
+
if self.flatten:
|
152 |
+
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
|
153 |
+
if self.layer_norm:
|
154 |
+
latent = self.norm(latent)
|
155 |
+
return latent + self.pos_embed
|
156 |
+
|
157 |
+
|
158 |
+
class TimestepEmbedding(nn.Module):
|
159 |
+
def __init__(
|
160 |
+
self,
|
161 |
+
in_channels: int,
|
162 |
+
time_embed_dim: int,
|
163 |
+
act_fn: str = "silu",
|
164 |
+
out_dim: int = None,
|
165 |
+
post_act_fn: Optional[str] = None,
|
166 |
+
cond_proj_dim=None,
|
167 |
+
):
|
168 |
+
super().__init__()
|
169 |
+
|
170 |
+
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
171 |
+
|
172 |
+
if cond_proj_dim is not None:
|
173 |
+
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
|
174 |
+
else:
|
175 |
+
self.cond_proj = None
|
176 |
+
|
177 |
+
self.act = get_activation(act_fn)
|
178 |
+
|
179 |
+
if out_dim is not None:
|
180 |
+
time_embed_dim_out = out_dim
|
181 |
+
else:
|
182 |
+
time_embed_dim_out = time_embed_dim
|
183 |
+
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
|
184 |
+
|
185 |
+
if post_act_fn is None:
|
186 |
+
self.post_act = None
|
187 |
+
else:
|
188 |
+
self.post_act = get_activation(post_act_fn)
|
189 |
+
|
190 |
+
def forward(self, sample, condition=None):
|
191 |
+
if condition is not None:
|
192 |
+
sample = sample + self.cond_proj(condition)
|
193 |
+
sample = self.linear_1(sample)
|
194 |
+
|
195 |
+
if self.act is not None:
|
196 |
+
sample = self.act(sample)
|
197 |
+
|
198 |
+
sample = self.linear_2(sample)
|
199 |
+
|
200 |
+
if self.post_act is not None:
|
201 |
+
sample = self.post_act(sample)
|
202 |
+
return sample
|
203 |
+
|
204 |
+
|
205 |
+
class Timesteps(nn.Module):
|
206 |
+
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
|
207 |
+
super().__init__()
|
208 |
+
self.num_channels = num_channels
|
209 |
+
self.flip_sin_to_cos = flip_sin_to_cos
|
210 |
+
self.downscale_freq_shift = downscale_freq_shift
|
211 |
+
|
212 |
+
def forward(self, timesteps):
|
213 |
+
t_emb = get_timestep_embedding(
|
214 |
+
timesteps,
|
215 |
+
self.num_channels,
|
216 |
+
flip_sin_to_cos=self.flip_sin_to_cos,
|
217 |
+
downscale_freq_shift=self.downscale_freq_shift,
|
218 |
+
)
|
219 |
+
return t_emb
|
220 |
+
|
221 |
+
|
222 |
+
class GaussianFourierProjection(nn.Module):
|
223 |
+
"""Gaussian Fourier embeddings for noise levels."""
|
224 |
+
|
225 |
+
def __init__(
|
226 |
+
self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
|
227 |
+
):
|
228 |
+
super().__init__()
|
229 |
+
self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
|
230 |
+
self.log = log
|
231 |
+
self.flip_sin_to_cos = flip_sin_to_cos
|
232 |
+
|
233 |
+
if set_W_to_weight:
|
234 |
+
# to delete later
|
235 |
+
self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
|
236 |
+
|
237 |
+
self.weight = self.W
|
238 |
+
|
239 |
+
def forward(self, x):
|
240 |
+
if self.log:
|
241 |
+
x = torch.log(x)
|
242 |
+
|
243 |
+
x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
|
244 |
+
|
245 |
+
if self.flip_sin_to_cos:
|
246 |
+
out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
|
247 |
+
else:
|
248 |
+
out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
|
249 |
+
return out
|
250 |
+
|
251 |
+
|
252 |
+
class ImagePositionalEmbeddings(nn.Module):
|
253 |
+
"""
|
254 |
+
Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
|
255 |
+
height and width of the latent space.
|
256 |
+
|
257 |
+
For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
|
258 |
+
|
259 |
+
For VQ-diffusion:
|
260 |
+
|
261 |
+
Output vector embeddings are used as input for the transformer.
|
262 |
+
|
263 |
+
Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
|
264 |
+
|
265 |
+
Args:
|
266 |
+
num_embed (`int`):
|
267 |
+
Number of embeddings for the latent pixels embeddings.
|
268 |
+
height (`int`):
|
269 |
+
Height of the latent image i.e. the number of height embeddings.
|
270 |
+
width (`int`):
|
271 |
+
Width of the latent image i.e. the number of width embeddings.
|
272 |
+
embed_dim (`int`):
|
273 |
+
Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
|
274 |
+
"""
|
275 |
+
|
276 |
+
def __init__(
|
277 |
+
self,
|
278 |
+
num_embed: int,
|
279 |
+
height: int,
|
280 |
+
width: int,
|
281 |
+
embed_dim: int,
|
282 |
+
):
|
283 |
+
super().__init__()
|
284 |
+
|
285 |
+
self.height = height
|
286 |
+
self.width = width
|
287 |
+
self.num_embed = num_embed
|
288 |
+
self.embed_dim = embed_dim
|
289 |
+
|
290 |
+
self.emb = nn.Embedding(self.num_embed, embed_dim)
|
291 |
+
self.height_emb = nn.Embedding(self.height, embed_dim)
|
292 |
+
self.width_emb = nn.Embedding(self.width, embed_dim)
|
293 |
+
|
294 |
+
def forward(self, index):
|
295 |
+
emb = self.emb(index)
|
296 |
+
|
297 |
+
height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
|
298 |
+
|
299 |
+
# 1 x H x D -> 1 x H x 1 x D
|
300 |
+
height_emb = height_emb.unsqueeze(2)
|
301 |
+
|
302 |
+
width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
|
303 |
+
|
304 |
+
# 1 x W x D -> 1 x 1 x W x D
|
305 |
+
width_emb = width_emb.unsqueeze(1)
|
306 |
+
|
307 |
+
pos_emb = height_emb + width_emb
|
308 |
+
|
309 |
+
# 1 x H x W x D -> 1 x L xD
|
310 |
+
pos_emb = pos_emb.view(1, self.height * self.width, -1)
|
311 |
+
|
312 |
+
emb = emb + pos_emb[:, : emb.shape[1], :]
|
313 |
+
|
314 |
+
return emb
|
315 |
+
|
316 |
+
|
317 |
+
class LabelEmbedding(nn.Module):
|
318 |
+
"""
|
319 |
+
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
320 |
+
|
321 |
+
Args:
|
322 |
+
num_classes (`int`): The number of classes.
|
323 |
+
hidden_size (`int`): The size of the vector embeddings.
|
324 |
+
dropout_prob (`float`): The probability of dropping a label.
|
325 |
+
"""
|
326 |
+
|
327 |
+
def __init__(self, num_classes, hidden_size, dropout_prob):
|
328 |
+
super().__init__()
|
329 |
+
use_cfg_embedding = dropout_prob > 0
|
330 |
+
self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
|
331 |
+
self.num_classes = num_classes
|
332 |
+
self.dropout_prob = dropout_prob
|
333 |
+
|
334 |
+
def token_drop(self, labels, force_drop_ids=None):
|
335 |
+
"""
|
336 |
+
Drops labels to enable classifier-free guidance.
|
337 |
+
"""
|
338 |
+
if force_drop_ids is None:
|
339 |
+
drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
|
340 |
+
else:
|
341 |
+
drop_ids = torch.tensor(force_drop_ids == 1)
|
342 |
+
labels = torch.where(drop_ids, self.num_classes, labels)
|
343 |
+
return labels
|
344 |
+
|
345 |
+
def forward(self, labels: torch.LongTensor, force_drop_ids=None):
|
346 |
+
use_dropout = self.dropout_prob > 0
|
347 |
+
if (self.training and use_dropout) or (force_drop_ids is not None):
|
348 |
+
labels = self.token_drop(labels, force_drop_ids)
|
349 |
+
embeddings = self.embedding_table(labels)
|
350 |
+
return embeddings
|
351 |
+
|
352 |
+
|
353 |
+
class TextImageProjection(nn.Module):
|
354 |
+
def __init__(
|
355 |
+
self,
|
356 |
+
text_embed_dim: int = 1024,
|
357 |
+
image_embed_dim: int = 768,
|
358 |
+
cross_attention_dim: int = 768,
|
359 |
+
num_image_text_embeds: int = 10,
|
360 |
+
):
|
361 |
+
super().__init__()
|
362 |
+
|
363 |
+
self.num_image_text_embeds = num_image_text_embeds
|
364 |
+
self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
|
365 |
+
self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)
|
366 |
+
|
367 |
+
def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
|
368 |
+
batch_size = text_embeds.shape[0]
|
369 |
+
|
370 |
+
# image
|
371 |
+
image_text_embeds = self.image_embeds(image_embeds)
|
372 |
+
image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
|
373 |
+
|
374 |
+
# text
|
375 |
+
text_embeds = self.text_proj(text_embeds)
|
376 |
+
|
377 |
+
return torch.cat([image_text_embeds, text_embeds], dim=1)
|
378 |
+
|
379 |
+
|
380 |
+
class CombinedTimestepLabelEmbeddings(nn.Module):
|
381 |
+
def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
|
382 |
+
super().__init__()
|
383 |
+
|
384 |
+
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
|
385 |
+
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
386 |
+
self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
|
387 |
+
|
388 |
+
def forward(self, timestep, class_labels, hidden_dtype=None):
|
389 |
+
timesteps_proj = self.time_proj(timestep)
|
390 |
+
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
|
391 |
+
|
392 |
+
class_labels = self.class_embedder(class_labels) # (N, D)
|
393 |
+
|
394 |
+
conditioning = timesteps_emb + class_labels # (N, D)
|
395 |
+
|
396 |
+
return conditioning
|
397 |
+
|
398 |
+
|
399 |
+
class TextTimeEmbedding(nn.Module):
|
400 |
+
def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
|
401 |
+
super().__init__()
|
402 |
+
self.norm1 = nn.LayerNorm(encoder_dim)
|
403 |
+
self.pool = AttentionPooling(num_heads, encoder_dim)
|
404 |
+
self.proj = nn.Linear(encoder_dim, time_embed_dim)
|
405 |
+
self.norm2 = nn.LayerNorm(time_embed_dim)
|
406 |
+
|
407 |
+
def forward(self, hidden_states):
|
408 |
+
hidden_states = self.norm1(hidden_states)
|
409 |
+
hidden_states = self.pool(hidden_states)
|
410 |
+
hidden_states = self.proj(hidden_states)
|
411 |
+
hidden_states = self.norm2(hidden_states)
|
412 |
+
return hidden_states
|
413 |
+
|
414 |
+
|
415 |
+
class TextImageTimeEmbedding(nn.Module):
|
416 |
+
def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
|
417 |
+
super().__init__()
|
418 |
+
self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
|
419 |
+
self.text_norm = nn.LayerNorm(time_embed_dim)
|
420 |
+
self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
|
421 |
+
|
422 |
+
def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
|
423 |
+
# text
|
424 |
+
time_text_embeds = self.text_proj(text_embeds)
|
425 |
+
time_text_embeds = self.text_norm(time_text_embeds)
|
426 |
+
|
427 |
+
# image
|
428 |
+
time_image_embeds = self.image_proj(image_embeds)
|
429 |
+
|
430 |
+
return time_image_embeds + time_text_embeds
|
431 |
+
|
432 |
+
|
433 |
+
class AttentionPooling(nn.Module):
|
434 |
+
# Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
|
435 |
+
|
436 |
+
def __init__(self, num_heads, embed_dim, dtype=None):
|
437 |
+
super().__init__()
|
438 |
+
self.dtype = dtype
|
439 |
+
self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
|
440 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
|
441 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
|
442 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
|
443 |
+
self.num_heads = num_heads
|
444 |
+
self.dim_per_head = embed_dim // self.num_heads
|
445 |
+
|
446 |
+
def forward(self, x):
|
447 |
+
bs, length, width = x.size()
|
448 |
+
|
449 |
+
def shape(x):
|
450 |
+
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
451 |
+
x = x.view(bs, -1, self.num_heads, self.dim_per_head)
|
452 |
+
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
453 |
+
x = x.transpose(1, 2)
|
454 |
+
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
455 |
+
x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
|
456 |
+
# (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
|
457 |
+
x = x.transpose(1, 2)
|
458 |
+
return x
|
459 |
+
|
460 |
+
class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
|
461 |
+
x = torch.cat([class_token, x], dim=1) # (bs, length+1, width)
|
462 |
+
|
463 |
+
# (bs*n_heads, class_token_length, dim_per_head)
|
464 |
+
q = shape(self.q_proj(class_token))
|
465 |
+
# (bs*n_heads, length+class_token_length, dim_per_head)
|
466 |
+
k = shape(self.k_proj(x))
|
467 |
+
v = shape(self.v_proj(x))
|
468 |
+
|
469 |
+
# (bs*n_heads, class_token_length, length+class_token_length):
|
470 |
+
scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
|
471 |
+
weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
|
472 |
+
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
473 |
+
|
474 |
+
# (bs*n_heads, dim_per_head, class_token_length)
|
475 |
+
a = torch.einsum("bts,bcs->bct", weight, v)
|
476 |
+
|
477 |
+
# (bs, length+1, width)
|
478 |
+
a = a.reshape(bs, -1, 1).transpose(1, 2)
|
479 |
+
|
480 |
+
return a[:, 0, :] # cls_token
|
diffusers/models/loaders.py
ADDED
@@ -0,0 +1,1481 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
import warnings
|
17 |
+
from collections import defaultdict
|
18 |
+
from pathlib import Path
|
19 |
+
from typing import Callable, Dict, List, Optional, Union
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn.functional as F
|
23 |
+
from huggingface_hub import hf_hub_download
|
24 |
+
|
25 |
+
from .attention_processor import (
|
26 |
+
AttnAddedKVProcessor,
|
27 |
+
AttnAddedKVProcessor2_0,
|
28 |
+
CustomDiffusionAttnProcessor,
|
29 |
+
CustomDiffusionXFormersAttnProcessor,
|
30 |
+
LoRAAttnAddedKVProcessor,
|
31 |
+
LoRAAttnProcessor,
|
32 |
+
LoRAAttnProcessor2_0,
|
33 |
+
LoRAXFormersAttnProcessor,
|
34 |
+
SlicedAttnAddedKVProcessor,
|
35 |
+
XFormersAttnProcessor,
|
36 |
+
)
|
37 |
+
from ..utils.constants import DIFFUSERS_CACHE, TEXT_ENCODER_ATTN_MODULE
|
38 |
+
from ..utils.hub_utils import HF_HUB_OFFLINE, _get_model_file
|
39 |
+
from ..utils.deprecation_utils import deprecate
|
40 |
+
from ..utils.import_utils import is_safetensors_available, is_transformers_available
|
41 |
+
from ..utils.logging import get_logger
|
42 |
+
|
43 |
+
if is_safetensors_available():
|
44 |
+
import safetensors
|
45 |
+
|
46 |
+
if is_transformers_available():
|
47 |
+
from transformers import PreTrainedModel, PreTrainedTokenizer
|
48 |
+
|
49 |
+
|
50 |
+
logger = get_logger(__name__)
|
51 |
+
|
52 |
+
TEXT_ENCODER_NAME = "text_encoder"
|
53 |
+
UNET_NAME = "unet"
|
54 |
+
|
55 |
+
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
|
56 |
+
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
|
57 |
+
|
58 |
+
TEXT_INVERSION_NAME = "learned_embeds.bin"
|
59 |
+
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
|
60 |
+
|
61 |
+
CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
|
62 |
+
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
|
63 |
+
|
64 |
+
|
65 |
+
class AttnProcsLayers(torch.nn.Module):
|
66 |
+
def __init__(self, state_dict: Dict[str, torch.Tensor]):
|
67 |
+
super().__init__()
|
68 |
+
self.layers = torch.nn.ModuleList(state_dict.values())
|
69 |
+
self.mapping = dict(enumerate(state_dict.keys()))
|
70 |
+
self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
|
71 |
+
|
72 |
+
# .processor for unet, .self_attn for text encoder
|
73 |
+
self.split_keys = [".processor", ".self_attn"]
|
74 |
+
|
75 |
+
# we add a hook to state_dict() and load_state_dict() so that the
|
76 |
+
# naming fits with `unet.attn_processors`
|
77 |
+
def map_to(module, state_dict, *args, **kwargs):
|
78 |
+
new_state_dict = {}
|
79 |
+
for key, value in state_dict.items():
|
80 |
+
num = int(key.split(".")[1]) # 0 is always "layers"
|
81 |
+
new_key = key.replace(f"layers.{num}", module.mapping[num])
|
82 |
+
new_state_dict[new_key] = value
|
83 |
+
|
84 |
+
return new_state_dict
|
85 |
+
|
86 |
+
def remap_key(key, state_dict):
|
87 |
+
for k in self.split_keys:
|
88 |
+
if k in key:
|
89 |
+
return key.split(k)[0] + k
|
90 |
+
|
91 |
+
raise ValueError(
|
92 |
+
f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}."
|
93 |
+
)
|
94 |
+
|
95 |
+
def map_from(module, state_dict, *args, **kwargs):
|
96 |
+
all_keys = list(state_dict.keys())
|
97 |
+
for key in all_keys:
|
98 |
+
replace_key = remap_key(key, state_dict)
|
99 |
+
new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
|
100 |
+
state_dict[new_key] = state_dict[key]
|
101 |
+
del state_dict[key]
|
102 |
+
|
103 |
+
self._register_state_dict_hook(map_to)
|
104 |
+
self._register_load_state_dict_pre_hook(map_from, with_module=True)
|
105 |
+
|
106 |
+
|
107 |
+
class UNet2DConditionLoadersMixin:
|
108 |
+
text_encoder_name = TEXT_ENCODER_NAME
|
109 |
+
unet_name = UNET_NAME
|
110 |
+
|
111 |
+
def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
|
112 |
+
r"""
|
113 |
+
Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
|
114 |
+
defined in
|
115 |
+
[`cross_attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py)
|
116 |
+
and be a `torch.nn.Module` class.
|
117 |
+
|
118 |
+
Parameters:
|
119 |
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
120 |
+
Can be either:
|
121 |
+
|
122 |
+
- A string, the model id (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
123 |
+
the Hub.
|
124 |
+
- A path to a directory (for example `./my_model_directory`) containing the model weights saved
|
125 |
+
with [`ModelMixin.save_pretrained`].
|
126 |
+
- A [torch state
|
127 |
+
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
128 |
+
|
129 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
130 |
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
131 |
+
is not used.
|
132 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
133 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
134 |
+
cached versions if they exist.
|
135 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
136 |
+
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
137 |
+
incompletely downloaded files are deleted.
|
138 |
+
proxies (`Dict[str, str]`, *optional*):
|
139 |
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
140 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
141 |
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
142 |
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
143 |
+
won't be downloaded from the Hub.
|
144 |
+
use_auth_token (`str` or *bool*, *optional*):
|
145 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
146 |
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
147 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
148 |
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
149 |
+
allowed by Git.
|
150 |
+
subfolder (`str`, *optional*, defaults to `""`):
|
151 |
+
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
152 |
+
mirror (`str`, *optional*):
|
153 |
+
Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
|
154 |
+
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
155 |
+
information.
|
156 |
+
|
157 |
+
"""
|
158 |
+
|
159 |
+
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
160 |
+
force_download = kwargs.pop("force_download", False)
|
161 |
+
resume_download = kwargs.pop("resume_download", False)
|
162 |
+
proxies = kwargs.pop("proxies", None)
|
163 |
+
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
164 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
165 |
+
revision = kwargs.pop("revision", None)
|
166 |
+
subfolder = kwargs.pop("subfolder", None)
|
167 |
+
weight_name = kwargs.pop("weight_name", None)
|
168 |
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
169 |
+
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
|
170 |
+
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
|
171 |
+
network_alpha = kwargs.pop("network_alpha", None)
|
172 |
+
|
173 |
+
if use_safetensors and not is_safetensors_available():
|
174 |
+
raise ValueError(
|
175 |
+
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
|
176 |
+
)
|
177 |
+
|
178 |
+
allow_pickle = False
|
179 |
+
if use_safetensors is None:
|
180 |
+
use_safetensors = is_safetensors_available()
|
181 |
+
allow_pickle = True
|
182 |
+
|
183 |
+
user_agent = {
|
184 |
+
"file_type": "attn_procs_weights",
|
185 |
+
"framework": "pytorch",
|
186 |
+
}
|
187 |
+
|
188 |
+
model_file = None
|
189 |
+
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
190 |
+
# Let's first try to load .safetensors weights
|
191 |
+
if (use_safetensors and weight_name is None) or (
|
192 |
+
weight_name is not None and weight_name.endswith(".safetensors")
|
193 |
+
):
|
194 |
+
try:
|
195 |
+
model_file = _get_model_file(
|
196 |
+
pretrained_model_name_or_path_or_dict,
|
197 |
+
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
|
198 |
+
cache_dir=cache_dir,
|
199 |
+
force_download=force_download,
|
200 |
+
resume_download=resume_download,
|
201 |
+
proxies=proxies,
|
202 |
+
local_files_only=local_files_only,
|
203 |
+
use_auth_token=use_auth_token,
|
204 |
+
revision=revision,
|
205 |
+
subfolder=subfolder,
|
206 |
+
user_agent=user_agent,
|
207 |
+
)
|
208 |
+
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
209 |
+
except IOError as e:
|
210 |
+
if not allow_pickle:
|
211 |
+
raise e
|
212 |
+
# try loading non-safetensors weights
|
213 |
+
pass
|
214 |
+
if model_file is None:
|
215 |
+
model_file = _get_model_file(
|
216 |
+
pretrained_model_name_or_path_or_dict,
|
217 |
+
weights_name=weight_name or LORA_WEIGHT_NAME,
|
218 |
+
cache_dir=cache_dir,
|
219 |
+
force_download=force_download,
|
220 |
+
resume_download=resume_download,
|
221 |
+
proxies=proxies,
|
222 |
+
local_files_only=local_files_only,
|
223 |
+
use_auth_token=use_auth_token,
|
224 |
+
revision=revision,
|
225 |
+
subfolder=subfolder,
|
226 |
+
user_agent=user_agent,
|
227 |
+
)
|
228 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
229 |
+
else:
|
230 |
+
state_dict = pretrained_model_name_or_path_or_dict
|
231 |
+
|
232 |
+
# fill attn processors
|
233 |
+
attn_processors = {}
|
234 |
+
|
235 |
+
is_lora = all("lora" in k for k in state_dict.keys())
|
236 |
+
is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
|
237 |
+
|
238 |
+
if is_lora:
|
239 |
+
is_new_lora_format = all(
|
240 |
+
key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
|
241 |
+
)
|
242 |
+
if is_new_lora_format:
|
243 |
+
# Strip the `"unet"` prefix.
|
244 |
+
is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
|
245 |
+
if is_text_encoder_present:
|
246 |
+
warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
|
247 |
+
warnings.warn(warn_message)
|
248 |
+
unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
|
249 |
+
state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
|
250 |
+
|
251 |
+
lora_grouped_dict = defaultdict(dict)
|
252 |
+
for key, value in state_dict.items():
|
253 |
+
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
|
254 |
+
lora_grouped_dict[attn_processor_key][sub_key] = value
|
255 |
+
|
256 |
+
for key, value_dict in lora_grouped_dict.items():
|
257 |
+
rank = value_dict["to_k_lora.down.weight"].shape[0]
|
258 |
+
hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
|
259 |
+
|
260 |
+
attn_processor = self
|
261 |
+
for sub_key in key.split("."):
|
262 |
+
attn_processor = getattr(attn_processor, sub_key)
|
263 |
+
|
264 |
+
if isinstance(
|
265 |
+
attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
|
266 |
+
):
|
267 |
+
cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1]
|
268 |
+
attn_processor_class = LoRAAttnAddedKVProcessor
|
269 |
+
else:
|
270 |
+
cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
|
271 |
+
if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
|
272 |
+
attn_processor_class = LoRAXFormersAttnProcessor
|
273 |
+
else:
|
274 |
+
attn_processor_class = (
|
275 |
+
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
276 |
+
)
|
277 |
+
|
278 |
+
attn_processors[key] = attn_processor_class(
|
279 |
+
hidden_size=hidden_size,
|
280 |
+
cross_attention_dim=cross_attention_dim,
|
281 |
+
rank=rank,
|
282 |
+
network_alpha=network_alpha,
|
283 |
+
)
|
284 |
+
attn_processors[key].load_state_dict(value_dict)
|
285 |
+
elif is_custom_diffusion:
|
286 |
+
custom_diffusion_grouped_dict = defaultdict(dict)
|
287 |
+
for key, value in state_dict.items():
|
288 |
+
if len(value) == 0:
|
289 |
+
custom_diffusion_grouped_dict[key] = {}
|
290 |
+
else:
|
291 |
+
if "to_out" in key:
|
292 |
+
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
|
293 |
+
else:
|
294 |
+
attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
|
295 |
+
custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
|
296 |
+
|
297 |
+
for key, value_dict in custom_diffusion_grouped_dict.items():
|
298 |
+
if len(value_dict) == 0:
|
299 |
+
attn_processors[key] = CustomDiffusionAttnProcessor(
|
300 |
+
train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
|
301 |
+
)
|
302 |
+
else:
|
303 |
+
cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
|
304 |
+
hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
|
305 |
+
train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
|
306 |
+
attn_processors[key] = CustomDiffusionAttnProcessor(
|
307 |
+
train_kv=True,
|
308 |
+
train_q_out=train_q_out,
|
309 |
+
hidden_size=hidden_size,
|
310 |
+
cross_attention_dim=cross_attention_dim,
|
311 |
+
)
|
312 |
+
attn_processors[key].load_state_dict(value_dict)
|
313 |
+
else:
|
314 |
+
raise ValueError(
|
315 |
+
f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
|
316 |
+
)
|
317 |
+
|
318 |
+
# set correct dtype & device
|
319 |
+
attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
|
320 |
+
|
321 |
+
# set layers
|
322 |
+
self.set_attn_processor(attn_processors)
|
323 |
+
|
324 |
+
def save_attn_procs(
|
325 |
+
self,
|
326 |
+
save_directory: Union[str, os.PathLike],
|
327 |
+
is_main_process: bool = True,
|
328 |
+
weight_name: str = None,
|
329 |
+
save_function: Callable = None,
|
330 |
+
safe_serialization: bool = False,
|
331 |
+
**kwargs,
|
332 |
+
):
|
333 |
+
r"""
|
334 |
+
Save an attention processor to a directory so that it can be reloaded using the
|
335 |
+
[`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method.
|
336 |
+
|
337 |
+
Arguments:
|
338 |
+
save_directory (`str` or `os.PathLike`):
|
339 |
+
Directory to save an attention processor to. Will be created if it doesn't exist.
|
340 |
+
is_main_process (`bool`, *optional*, defaults to `True`):
|
341 |
+
Whether the process calling this is the main process or not. Useful during distributed training and you
|
342 |
+
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
343 |
+
process to avoid race conditions.
|
344 |
+
save_function (`Callable`):
|
345 |
+
The function to use to save the state dictionary. Useful during distributed training when you need to
|
346 |
+
replace `torch.save` with another method. Can be configured with the environment variable
|
347 |
+
`DIFFUSERS_SAVE_MODE`.
|
348 |
+
|
349 |
+
"""
|
350 |
+
weight_name = weight_name or deprecate(
|
351 |
+
"weights_name",
|
352 |
+
"0.20.0",
|
353 |
+
"`weights_name` is deprecated, please use `weight_name` instead.",
|
354 |
+
take_from=kwargs,
|
355 |
+
)
|
356 |
+
if os.path.isfile(save_directory):
|
357 |
+
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
358 |
+
return
|
359 |
+
|
360 |
+
if save_function is None:
|
361 |
+
if safe_serialization:
|
362 |
+
|
363 |
+
def save_function(weights, filename):
|
364 |
+
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
|
365 |
+
|
366 |
+
else:
|
367 |
+
save_function = torch.save
|
368 |
+
|
369 |
+
os.makedirs(save_directory, exist_ok=True)
|
370 |
+
|
371 |
+
is_custom_diffusion = any(
|
372 |
+
isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
|
373 |
+
for (_, x) in self.attn_processors.items()
|
374 |
+
)
|
375 |
+
if is_custom_diffusion:
|
376 |
+
model_to_save = AttnProcsLayers(
|
377 |
+
{
|
378 |
+
y: x
|
379 |
+
for (y, x) in self.attn_processors.items()
|
380 |
+
if isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
|
381 |
+
}
|
382 |
+
)
|
383 |
+
state_dict = model_to_save.state_dict()
|
384 |
+
for name, attn in self.attn_processors.items():
|
385 |
+
if len(attn.state_dict()) == 0:
|
386 |
+
state_dict[name] = {}
|
387 |
+
else:
|
388 |
+
model_to_save = AttnProcsLayers(self.attn_processors)
|
389 |
+
state_dict = model_to_save.state_dict()
|
390 |
+
|
391 |
+
if weight_name is None:
|
392 |
+
if safe_serialization:
|
393 |
+
weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE
|
394 |
+
else:
|
395 |
+
weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME
|
396 |
+
|
397 |
+
# Save the model
|
398 |
+
save_function(state_dict, os.path.join(save_directory, weight_name))
|
399 |
+
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
|
400 |
+
|
401 |
+
|
402 |
+
class TextualInversionLoaderMixin:
|
403 |
+
r"""
|
404 |
+
Load textual inversion tokens and embeddings to the tokenizer and text encoder.
|
405 |
+
"""
|
406 |
+
|
407 |
+
def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"):
|
408 |
+
r"""
|
409 |
+
Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to
|
410 |
+
be replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
|
411 |
+
inversion token or if the textual inversion token is a single vector, the input prompt is returned.
|
412 |
+
|
413 |
+
Parameters:
|
414 |
+
prompt (`str` or list of `str`):
|
415 |
+
The prompt or prompts to guide the image generation.
|
416 |
+
tokenizer (`PreTrainedTokenizer`):
|
417 |
+
The tokenizer responsible for encoding the prompt into input tokens.
|
418 |
+
|
419 |
+
Returns:
|
420 |
+
`str` or list of `str`: The converted prompt
|
421 |
+
"""
|
422 |
+
if not isinstance(prompt, List):
|
423 |
+
prompts = [prompt]
|
424 |
+
else:
|
425 |
+
prompts = prompt
|
426 |
+
|
427 |
+
prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]
|
428 |
+
|
429 |
+
if not isinstance(prompt, List):
|
430 |
+
return prompts[0]
|
431 |
+
|
432 |
+
return prompts
|
433 |
+
|
434 |
+
def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
|
435 |
+
r"""
|
436 |
+
Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
|
437 |
+
to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
|
438 |
+
is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
|
439 |
+
inversion token or a textual inversion token that is a single vector, the input prompt is simply returned.
|
440 |
+
|
441 |
+
Parameters:
|
442 |
+
prompt (`str`):
|
443 |
+
The prompt to guide the image generation.
|
444 |
+
tokenizer (`PreTrainedTokenizer`):
|
445 |
+
The tokenizer responsible for encoding the prompt into input tokens.
|
446 |
+
|
447 |
+
Returns:
|
448 |
+
`str`: The converted prompt
|
449 |
+
"""
|
450 |
+
tokens = tokenizer.tokenize(prompt)
|
451 |
+
unique_tokens = set(tokens)
|
452 |
+
for token in unique_tokens:
|
453 |
+
if token in tokenizer.added_tokens_encoder:
|
454 |
+
replacement = token
|
455 |
+
i = 1
|
456 |
+
while f"{token}_{i}" in tokenizer.added_tokens_encoder:
|
457 |
+
replacement += f" {token}_{i}"
|
458 |
+
i += 1
|
459 |
+
|
460 |
+
prompt = prompt.replace(token, replacement)
|
461 |
+
|
462 |
+
return prompt
|
463 |
+
|
464 |
+
def load_textual_inversion(
|
465 |
+
self,
|
466 |
+
pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
|
467 |
+
token: Optional[Union[str, List[str]]] = None,
|
468 |
+
**kwargs,
|
469 |
+
):
|
470 |
+
r"""
|
471 |
+
Load textual inversion embeddings into the text encoder of [`StableDiffusionPipeline`] (both 🤗 Diffusers and
|
472 |
+
Automatic1111 formats are supported).
|
473 |
+
|
474 |
+
Parameters:
|
475 |
+
pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
|
476 |
+
Can be either one of the following or a list of them:
|
477 |
+
|
478 |
+
- A string, the *model id* (for example `sd-concepts-library/low-poly-hd-logos-icons`) of a
|
479 |
+
pretrained model hosted on the Hub.
|
480 |
+
- A path to a *directory* (for example `./my_text_inversion_directory/`) containing the textual
|
481 |
+
inversion weights.
|
482 |
+
- A path to a *file* (for example `./my_text_inversions.pt`) containing textual inversion weights.
|
483 |
+
- A [torch state
|
484 |
+
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
485 |
+
|
486 |
+
token (`str` or `List[str]`, *optional*):
|
487 |
+
Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
|
488 |
+
list, then `token` must also be a list of equal length.
|
489 |
+
weight_name (`str`, *optional*):
|
490 |
+
Name of a custom weight file. This should be used when:
|
491 |
+
|
492 |
+
- The saved textual inversion file is in 🤗 Diffusers format, but was saved under a specific weight
|
493 |
+
name such as `text_inv.bin`.
|
494 |
+
- The saved textual inversion file is in the Automatic1111 format.
|
495 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
496 |
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
497 |
+
is not used.
|
498 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
499 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
500 |
+
cached versions if they exist.
|
501 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
502 |
+
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
503 |
+
incompletely downloaded files are deleted.
|
504 |
+
proxies (`Dict[str, str]`, *optional*):
|
505 |
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
506 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
507 |
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
508 |
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
509 |
+
won't be downloaded from the Hub.
|
510 |
+
use_auth_token (`str` or *bool*, *optional*):
|
511 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
512 |
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
513 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
514 |
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
515 |
+
allowed by Git.
|
516 |
+
subfolder (`str`, *optional*, defaults to `""`):
|
517 |
+
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
518 |
+
mirror (`str`, *optional*):
|
519 |
+
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
|
520 |
+
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
521 |
+
information.
|
522 |
+
|
523 |
+
Example:
|
524 |
+
|
525 |
+
To load a textual inversion embedding vector in 🤗 Diffusers format:
|
526 |
+
|
527 |
+
```py
|
528 |
+
from diffusers import StableDiffusionPipeline
|
529 |
+
import torch
|
530 |
+
|
531 |
+
model_id = "runwayml/stable-diffusion-v1-5"
|
532 |
+
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
|
533 |
+
|
534 |
+
pipe.load_textual_inversion("sd-concepts-library/cat-toy")
|
535 |
+
|
536 |
+
prompt = "A <cat-toy> backpack"
|
537 |
+
|
538 |
+
image = pipe(prompt, num_inference_steps=50).images[0]
|
539 |
+
image.save("cat-backpack.png")
|
540 |
+
```
|
541 |
+
|
542 |
+
To load a textual inversion embedding vector in Automatic1111 format, make sure to download the vector first
|
543 |
+
(for example from [civitAI](https://civitai.com/models/3036?modelVersionId=9857)) and then load the vector
|
544 |
+
locally:
|
545 |
+
|
546 |
+
```py
|
547 |
+
from diffusers import StableDiffusionPipeline
|
548 |
+
import torch
|
549 |
+
|
550 |
+
model_id = "runwayml/stable-diffusion-v1-5"
|
551 |
+
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
|
552 |
+
|
553 |
+
pipe.load_textual_inversion("./charturnerv2.pt", token="charturnerv2")
|
554 |
+
|
555 |
+
prompt = "charturnerv2, multiple views of the same character in the same outfit, a character turnaround of a woman wearing a black jacket and red shirt, best quality, intricate details."
|
556 |
+
|
557 |
+
image = pipe(prompt, num_inference_steps=50).images[0]
|
558 |
+
image.save("character.png")
|
559 |
+
```
|
560 |
+
|
561 |
+
"""
|
562 |
+
if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
|
563 |
+
raise ValueError(
|
564 |
+
f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling"
|
565 |
+
f" `{self.load_textual_inversion.__name__}`"
|
566 |
+
)
|
567 |
+
|
568 |
+
if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel):
|
569 |
+
raise ValueError(
|
570 |
+
f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling"
|
571 |
+
f" `{self.load_textual_inversion.__name__}`"
|
572 |
+
)
|
573 |
+
|
574 |
+
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
575 |
+
force_download = kwargs.pop("force_download", False)
|
576 |
+
resume_download = kwargs.pop("resume_download", False)
|
577 |
+
proxies = kwargs.pop("proxies", None)
|
578 |
+
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
579 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
580 |
+
revision = kwargs.pop("revision", None)
|
581 |
+
subfolder = kwargs.pop("subfolder", None)
|
582 |
+
weight_name = kwargs.pop("weight_name", None)
|
583 |
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
584 |
+
|
585 |
+
if use_safetensors and not is_safetensors_available():
|
586 |
+
raise ValueError(
|
587 |
+
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
|
588 |
+
)
|
589 |
+
|
590 |
+
allow_pickle = False
|
591 |
+
if use_safetensors is None:
|
592 |
+
use_safetensors = is_safetensors_available()
|
593 |
+
allow_pickle = True
|
594 |
+
|
595 |
+
user_agent = {
|
596 |
+
"file_type": "text_inversion",
|
597 |
+
"framework": "pytorch",
|
598 |
+
}
|
599 |
+
|
600 |
+
if not isinstance(pretrained_model_name_or_path, list):
|
601 |
+
pretrained_model_name_or_paths = [pretrained_model_name_or_path]
|
602 |
+
else:
|
603 |
+
pretrained_model_name_or_paths = pretrained_model_name_or_path
|
604 |
+
|
605 |
+
if isinstance(token, str):
|
606 |
+
tokens = [token]
|
607 |
+
elif token is None:
|
608 |
+
tokens = [None] * len(pretrained_model_name_or_paths)
|
609 |
+
else:
|
610 |
+
tokens = token
|
611 |
+
|
612 |
+
if len(pretrained_model_name_or_paths) != len(tokens):
|
613 |
+
raise ValueError(
|
614 |
+
f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)}"
|
615 |
+
f"Make sure both lists have the same length."
|
616 |
+
)
|
617 |
+
|
618 |
+
valid_tokens = [t for t in tokens if t is not None]
|
619 |
+
if len(set(valid_tokens)) < len(valid_tokens):
|
620 |
+
raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")
|
621 |
+
|
622 |
+
token_ids_and_embeddings = []
|
623 |
+
|
624 |
+
for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
|
625 |
+
if not isinstance(pretrained_model_name_or_path, dict):
|
626 |
+
# 1. Load textual inversion file
|
627 |
+
model_file = None
|
628 |
+
# Let's first try to load .safetensors weights
|
629 |
+
if (use_safetensors and weight_name is None) or (
|
630 |
+
weight_name is not None and weight_name.endswith(".safetensors")
|
631 |
+
):
|
632 |
+
try:
|
633 |
+
model_file = _get_model_file(
|
634 |
+
pretrained_model_name_or_path,
|
635 |
+
weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
|
636 |
+
cache_dir=cache_dir,
|
637 |
+
force_download=force_download,
|
638 |
+
resume_download=resume_download,
|
639 |
+
proxies=proxies,
|
640 |
+
local_files_only=local_files_only,
|
641 |
+
use_auth_token=use_auth_token,
|
642 |
+
revision=revision,
|
643 |
+
subfolder=subfolder,
|
644 |
+
user_agent=user_agent,
|
645 |
+
)
|
646 |
+
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
647 |
+
except Exception as e:
|
648 |
+
if not allow_pickle:
|
649 |
+
raise e
|
650 |
+
|
651 |
+
model_file = None
|
652 |
+
|
653 |
+
if model_file is None:
|
654 |
+
model_file = _get_model_file(
|
655 |
+
pretrained_model_name_or_path,
|
656 |
+
weights_name=weight_name or TEXT_INVERSION_NAME,
|
657 |
+
cache_dir=cache_dir,
|
658 |
+
force_download=force_download,
|
659 |
+
resume_download=resume_download,
|
660 |
+
proxies=proxies,
|
661 |
+
local_files_only=local_files_only,
|
662 |
+
use_auth_token=use_auth_token,
|
663 |
+
revision=revision,
|
664 |
+
subfolder=subfolder,
|
665 |
+
user_agent=user_agent,
|
666 |
+
)
|
667 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
668 |
+
else:
|
669 |
+
state_dict = pretrained_model_name_or_path
|
670 |
+
|
671 |
+
# 2. Load token and embedding correcly from file
|
672 |
+
loaded_token = None
|
673 |
+
if isinstance(state_dict, torch.Tensor):
|
674 |
+
if token is None:
|
675 |
+
raise ValueError(
|
676 |
+
"You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
|
677 |
+
)
|
678 |
+
embedding = state_dict
|
679 |
+
elif len(state_dict) == 1:
|
680 |
+
# diffusers
|
681 |
+
loaded_token, embedding = next(iter(state_dict.items()))
|
682 |
+
elif "string_to_param" in state_dict:
|
683 |
+
# A1111
|
684 |
+
loaded_token = state_dict["name"]
|
685 |
+
embedding = state_dict["string_to_param"]["*"]
|
686 |
+
|
687 |
+
if token is not None and loaded_token != token:
|
688 |
+
logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
|
689 |
+
else:
|
690 |
+
token = loaded_token
|
691 |
+
|
692 |
+
embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)
|
693 |
+
|
694 |
+
# 3. Make sure we don't mess up the tokenizer or text encoder
|
695 |
+
vocab = self.tokenizer.get_vocab()
|
696 |
+
if token in vocab:
|
697 |
+
raise ValueError(
|
698 |
+
f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
|
699 |
+
)
|
700 |
+
elif f"{token}_1" in vocab:
|
701 |
+
multi_vector_tokens = [token]
|
702 |
+
i = 1
|
703 |
+
while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
|
704 |
+
multi_vector_tokens.append(f"{token}_{i}")
|
705 |
+
i += 1
|
706 |
+
|
707 |
+
raise ValueError(
|
708 |
+
f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
|
709 |
+
)
|
710 |
+
|
711 |
+
is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
|
712 |
+
|
713 |
+
if is_multi_vector:
|
714 |
+
tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
|
715 |
+
embeddings = [e for e in embedding] # noqa: C416
|
716 |
+
else:
|
717 |
+
tokens = [token]
|
718 |
+
embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]
|
719 |
+
|
720 |
+
# add tokens and get ids
|
721 |
+
self.tokenizer.add_tokens(tokens)
|
722 |
+
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
|
723 |
+
token_ids_and_embeddings += zip(token_ids, embeddings)
|
724 |
+
|
725 |
+
logger.info(f"Loaded textual inversion embedding for {token}.")
|
726 |
+
|
727 |
+
# resize token embeddings and set all new embeddings
|
728 |
+
self.text_encoder.resize_token_embeddings(len(self.tokenizer))
|
729 |
+
for token_id, embedding in token_ids_and_embeddings:
|
730 |
+
self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
|
731 |
+
|
732 |
+
|
733 |
+
class LoraLoaderMixin:
|
734 |
+
r"""
|
735 |
+
Load LoRA layers into [`UNet2DConditionModel`] and
|
736 |
+
[`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
|
737 |
+
"""
|
738 |
+
text_encoder_name = TEXT_ENCODER_NAME
|
739 |
+
unet_name = UNET_NAME
|
740 |
+
|
741 |
+
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
|
742 |
+
r"""
|
743 |
+
Load pretrained LoRA attention processor layers into [`UNet2DConditionModel`] and
|
744 |
+
[`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
|
745 |
+
|
746 |
+
Parameters:
|
747 |
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
748 |
+
Can be either:
|
749 |
+
|
750 |
+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
751 |
+
the Hub.
|
752 |
+
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
753 |
+
with [`ModelMixin.save_pretrained`].
|
754 |
+
- A [torch state
|
755 |
+
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
756 |
+
|
757 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
758 |
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
759 |
+
is not used.
|
760 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
761 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
762 |
+
cached versions if they exist.
|
763 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
764 |
+
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
765 |
+
incompletely downloaded files are deleted.
|
766 |
+
proxies (`Dict[str, str]`, *optional*):
|
767 |
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
768 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
769 |
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
770 |
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
771 |
+
won't be downloaded from the Hub.
|
772 |
+
use_auth_token (`str` or *bool*, *optional*):
|
773 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
774 |
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
775 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
776 |
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
777 |
+
allowed by Git.
|
778 |
+
subfolder (`str`, *optional*, defaults to `""`):
|
779 |
+
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
780 |
+
mirror (`str`, *optional*):
|
781 |
+
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
|
782 |
+
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
783 |
+
information.
|
784 |
+
|
785 |
+
"""
|
786 |
+
# Load the main state dict first which has the LoRA layers for either of
|
787 |
+
# UNet and text encoder or both.
|
788 |
+
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
789 |
+
force_download = kwargs.pop("force_download", False)
|
790 |
+
resume_download = kwargs.pop("resume_download", False)
|
791 |
+
proxies = kwargs.pop("proxies", None)
|
792 |
+
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
793 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
794 |
+
revision = kwargs.pop("revision", None)
|
795 |
+
subfolder = kwargs.pop("subfolder", None)
|
796 |
+
weight_name = kwargs.pop("weight_name", None)
|
797 |
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
798 |
+
|
799 |
+
# set lora scale to a reasonable default
|
800 |
+
self._lora_scale = 1.0
|
801 |
+
|
802 |
+
if use_safetensors and not is_safetensors_available():
|
803 |
+
raise ValueError(
|
804 |
+
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
|
805 |
+
)
|
806 |
+
|
807 |
+
allow_pickle = False
|
808 |
+
if use_safetensors is None:
|
809 |
+
use_safetensors = is_safetensors_available()
|
810 |
+
allow_pickle = True
|
811 |
+
|
812 |
+
user_agent = {
|
813 |
+
"file_type": "attn_procs_weights",
|
814 |
+
"framework": "pytorch",
|
815 |
+
}
|
816 |
+
|
817 |
+
model_file = None
|
818 |
+
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
819 |
+
# Let's first try to load .safetensors weights
|
820 |
+
if (use_safetensors and weight_name is None) or (
|
821 |
+
weight_name is not None and weight_name.endswith(".safetensors")
|
822 |
+
):
|
823 |
+
try:
|
824 |
+
model_file = _get_model_file(
|
825 |
+
pretrained_model_name_or_path_or_dict,
|
826 |
+
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
|
827 |
+
cache_dir=cache_dir,
|
828 |
+
force_download=force_download,
|
829 |
+
resume_download=resume_download,
|
830 |
+
proxies=proxies,
|
831 |
+
local_files_only=local_files_only,
|
832 |
+
use_auth_token=use_auth_token,
|
833 |
+
revision=revision,
|
834 |
+
subfolder=subfolder,
|
835 |
+
user_agent=user_agent,
|
836 |
+
)
|
837 |
+
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
838 |
+
except IOError as e:
|
839 |
+
if not allow_pickle:
|
840 |
+
raise e
|
841 |
+
# try loading non-safetensors weights
|
842 |
+
pass
|
843 |
+
if model_file is None:
|
844 |
+
model_file = _get_model_file(
|
845 |
+
pretrained_model_name_or_path_or_dict,
|
846 |
+
weights_name=weight_name or LORA_WEIGHT_NAME,
|
847 |
+
cache_dir=cache_dir,
|
848 |
+
force_download=force_download,
|
849 |
+
resume_download=resume_download,
|
850 |
+
proxies=proxies,
|
851 |
+
local_files_only=local_files_only,
|
852 |
+
use_auth_token=use_auth_token,
|
853 |
+
revision=revision,
|
854 |
+
subfolder=subfolder,
|
855 |
+
user_agent=user_agent,
|
856 |
+
)
|
857 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
858 |
+
else:
|
859 |
+
state_dict = pretrained_model_name_or_path_or_dict
|
860 |
+
|
861 |
+
# Convert kohya-ss Style LoRA attn procs to diffusers attn procs
|
862 |
+
network_alpha = None
|
863 |
+
if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
|
864 |
+
state_dict, network_alpha = self._convert_kohya_lora_to_diffusers(state_dict)
|
865 |
+
|
866 |
+
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
867 |
+
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
|
868 |
+
# their prefixes.
|
869 |
+
keys = list(state_dict.keys())
|
870 |
+
if all(key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in keys):
|
871 |
+
# Load the layers corresponding to UNet.
|
872 |
+
unet_keys = [k for k in keys if k.startswith(self.unet_name)]
|
873 |
+
logger.info(f"Loading {self.unet_name}.")
|
874 |
+
unet_lora_state_dict = {
|
875 |
+
k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
|
876 |
+
}
|
877 |
+
self.unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha)
|
878 |
+
|
879 |
+
# Load the layers corresponding to text encoder and make necessary adjustments.
|
880 |
+
text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
|
881 |
+
text_encoder_lora_state_dict = {
|
882 |
+
k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
|
883 |
+
}
|
884 |
+
if len(text_encoder_lora_state_dict) > 0:
|
885 |
+
logger.info(f"Loading {self.text_encoder_name}.")
|
886 |
+
attn_procs_text_encoder = self._load_text_encoder_attn_procs(
|
887 |
+
text_encoder_lora_state_dict, network_alpha=network_alpha
|
888 |
+
)
|
889 |
+
self._modify_text_encoder(attn_procs_text_encoder)
|
890 |
+
|
891 |
+
# save lora attn procs of text encoder so that it can be easily retrieved
|
892 |
+
self._text_encoder_lora_attn_procs = attn_procs_text_encoder
|
893 |
+
|
894 |
+
# Otherwise, we're dealing with the old format. This means the `state_dict` should only
|
895 |
+
# contain the module names of the `unet` as its keys WITHOUT any prefix.
|
896 |
+
elif not all(
|
897 |
+
key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
|
898 |
+
):
|
899 |
+
self.unet.load_attn_procs(state_dict)
|
900 |
+
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
|
901 |
+
warnings.warn(warn_message)
|
902 |
+
|
903 |
+
@property
|
904 |
+
def lora_scale(self) -> float:
|
905 |
+
# property function that returns the lora scale which can be set at run time by the pipeline.
|
906 |
+
# if _lora_scale has not been set, return 1
|
907 |
+
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
|
908 |
+
|
909 |
+
@property
|
910 |
+
def text_encoder_lora_attn_procs(self):
|
911 |
+
if hasattr(self, "_text_encoder_lora_attn_procs"):
|
912 |
+
return self._text_encoder_lora_attn_procs
|
913 |
+
return
|
914 |
+
|
915 |
+
def _remove_text_encoder_monkey_patch(self):
|
916 |
+
# Loop over the CLIPAttention module of text_encoder
|
917 |
+
for name, attn_module in self.text_encoder.named_modules():
|
918 |
+
if name.endswith(TEXT_ENCODER_ATTN_MODULE):
|
919 |
+
# Loop over the LoRA layers
|
920 |
+
for _, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
|
921 |
+
# Retrieve the q/k/v/out projection of CLIPAttention
|
922 |
+
module = attn_module.get_submodule(text_encoder_attr)
|
923 |
+
if hasattr(module, "old_forward"):
|
924 |
+
# restore original `forward` to remove monkey-patch
|
925 |
+
module.forward = module.old_forward
|
926 |
+
delattr(module, "old_forward")
|
927 |
+
|
928 |
+
def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
|
929 |
+
r"""
|
930 |
+
Monkey-patches the forward passes of attention modules of the text encoder.
|
931 |
+
|
932 |
+
Parameters:
|
933 |
+
attn_processors: Dict[str, `LoRAAttnProcessor`]:
|
934 |
+
A dictionary mapping the module names and their corresponding [`~LoRAAttnProcessor`].
|
935 |
+
"""
|
936 |
+
|
937 |
+
# First, remove any monkey-patch that might have been applied before
|
938 |
+
self._remove_text_encoder_monkey_patch()
|
939 |
+
|
940 |
+
# Loop over the CLIPAttention module of text_encoder
|
941 |
+
for name, attn_module in self.text_encoder.named_modules():
|
942 |
+
if name.endswith(TEXT_ENCODER_ATTN_MODULE):
|
943 |
+
# Loop over the LoRA layers
|
944 |
+
for attn_proc_attr, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
|
945 |
+
# Retrieve the q/k/v/out projection of CLIPAttention and its corresponding LoRA layer.
|
946 |
+
module = attn_module.get_submodule(text_encoder_attr)
|
947 |
+
lora_layer = attn_processors[name].get_submodule(attn_proc_attr)
|
948 |
+
|
949 |
+
# save old_forward to module that can be used to remove monkey-patch
|
950 |
+
old_forward = module.old_forward = module.forward
|
951 |
+
|
952 |
+
# create a new scope that locks in the old_forward, lora_layer value for each new_forward function
|
953 |
+
# for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
|
954 |
+
def make_new_forward(old_forward, lora_layer):
|
955 |
+
def new_forward(x):
|
956 |
+
result = old_forward(x) + self.lora_scale * lora_layer(x)
|
957 |
+
return result
|
958 |
+
|
959 |
+
return new_forward
|
960 |
+
|
961 |
+
# Monkey-patch.
|
962 |
+
module.forward = make_new_forward(old_forward, lora_layer)
|
963 |
+
|
964 |
+
@property
|
965 |
+
def _lora_attn_processor_attr_to_text_encoder_attr(self):
|
966 |
+
return {
|
967 |
+
"to_q_lora": "q_proj",
|
968 |
+
"to_k_lora": "k_proj",
|
969 |
+
"to_v_lora": "v_proj",
|
970 |
+
"to_out_lora": "out_proj",
|
971 |
+
}
|
972 |
+
|
973 |
+
def _load_text_encoder_attn_procs(
|
974 |
+
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
|
975 |
+
):
|
976 |
+
r"""
|
977 |
+
Load pretrained attention processor layers for
|
978 |
+
[`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
|
979 |
+
|
980 |
+
<Tip warning={true}>
|
981 |
+
|
982 |
+
This function is experimental and might change in the future.
|
983 |
+
|
984 |
+
</Tip>
|
985 |
+
|
986 |
+
Parameters:
|
987 |
+
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
988 |
+
Can be either:
|
989 |
+
|
990 |
+
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
|
991 |
+
Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
|
992 |
+
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
|
993 |
+
`./my_model_directory/`.
|
994 |
+
- A [torch state
|
995 |
+
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
996 |
+
|
997 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
998 |
+
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
999 |
+
standard cache should not be used.
|
1000 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
1001 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
1002 |
+
cached versions if they exist.
|
1003 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
1004 |
+
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
1005 |
+
file exists.
|
1006 |
+
proxies (`Dict[str, str]`, *optional*):
|
1007 |
+
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
1008 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
1009 |
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
1010 |
+
Whether or not to only look at local files (i.e., do not try to download the model).
|
1011 |
+
use_auth_token (`str` or *bool*, *optional*):
|
1012 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
1013 |
+
when running `diffusers-cli login` (stored in `~/.huggingface`).
|
1014 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
1015 |
+
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
1016 |
+
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
1017 |
+
identifier allowed by git.
|
1018 |
+
subfolder (`str`, *optional*, defaults to `""`):
|
1019 |
+
In case the relevant files are located inside a subfolder of the model repo (either remote in
|
1020 |
+
huggingface.co or downloaded locally), you can specify the folder name here.
|
1021 |
+
mirror (`str`, *optional*):
|
1022 |
+
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
1023 |
+
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
1024 |
+
Please refer to the mirror site for more information.
|
1025 |
+
|
1026 |
+
Returns:
|
1027 |
+
`Dict[name, LoRAAttnProcessor]`: Mapping between the module names and their corresponding
|
1028 |
+
[`LoRAAttnProcessor`].
|
1029 |
+
|
1030 |
+
<Tip>
|
1031 |
+
|
1032 |
+
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
|
1033 |
+
models](https://huggingface.co/docs/hub/models-gated#gated-models).
|
1034 |
+
|
1035 |
+
</Tip>
|
1036 |
+
"""
|
1037 |
+
|
1038 |
+
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
1039 |
+
force_download = kwargs.pop("force_download", False)
|
1040 |
+
resume_download = kwargs.pop("resume_download", False)
|
1041 |
+
proxies = kwargs.pop("proxies", None)
|
1042 |
+
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
1043 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
1044 |
+
revision = kwargs.pop("revision", None)
|
1045 |
+
subfolder = kwargs.pop("subfolder", None)
|
1046 |
+
weight_name = kwargs.pop("weight_name", None)
|
1047 |
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
1048 |
+
network_alpha = kwargs.pop("network_alpha", None)
|
1049 |
+
|
1050 |
+
if use_safetensors and not is_safetensors_available():
|
1051 |
+
raise ValueError(
|
1052 |
+
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
|
1053 |
+
)
|
1054 |
+
|
1055 |
+
allow_pickle = False
|
1056 |
+
if use_safetensors is None:
|
1057 |
+
use_safetensors = is_safetensors_available()
|
1058 |
+
allow_pickle = True
|
1059 |
+
|
1060 |
+
user_agent = {
|
1061 |
+
"file_type": "attn_procs_weights",
|
1062 |
+
"framework": "pytorch",
|
1063 |
+
}
|
1064 |
+
|
1065 |
+
model_file = None
|
1066 |
+
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
1067 |
+
# Let's first try to load .safetensors weights
|
1068 |
+
if (use_safetensors and weight_name is None) or (
|
1069 |
+
weight_name is not None and weight_name.endswith(".safetensors")
|
1070 |
+
):
|
1071 |
+
try:
|
1072 |
+
model_file = _get_model_file(
|
1073 |
+
pretrained_model_name_or_path_or_dict,
|
1074 |
+
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
|
1075 |
+
cache_dir=cache_dir,
|
1076 |
+
force_download=force_download,
|
1077 |
+
resume_download=resume_download,
|
1078 |
+
proxies=proxies,
|
1079 |
+
local_files_only=local_files_only,
|
1080 |
+
use_auth_token=use_auth_token,
|
1081 |
+
revision=revision,
|
1082 |
+
subfolder=subfolder,
|
1083 |
+
user_agent=user_agent,
|
1084 |
+
)
|
1085 |
+
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
1086 |
+
except IOError as e:
|
1087 |
+
if not allow_pickle:
|
1088 |
+
raise e
|
1089 |
+
# try loading non-safetensors weights
|
1090 |
+
pass
|
1091 |
+
if model_file is None:
|
1092 |
+
model_file = _get_model_file(
|
1093 |
+
pretrained_model_name_or_path_or_dict,
|
1094 |
+
weights_name=weight_name or LORA_WEIGHT_NAME,
|
1095 |
+
cache_dir=cache_dir,
|
1096 |
+
force_download=force_download,
|
1097 |
+
resume_download=resume_download,
|
1098 |
+
proxies=proxies,
|
1099 |
+
local_files_only=local_files_only,
|
1100 |
+
use_auth_token=use_auth_token,
|
1101 |
+
revision=revision,
|
1102 |
+
subfolder=subfolder,
|
1103 |
+
user_agent=user_agent,
|
1104 |
+
)
|
1105 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
1106 |
+
else:
|
1107 |
+
state_dict = pretrained_model_name_or_path_or_dict
|
1108 |
+
|
1109 |
+
# fill attn processors
|
1110 |
+
attn_processors = {}
|
1111 |
+
|
1112 |
+
is_lora = all("lora" in k for k in state_dict.keys())
|
1113 |
+
|
1114 |
+
if is_lora:
|
1115 |
+
lora_grouped_dict = defaultdict(dict)
|
1116 |
+
for key, value in state_dict.items():
|
1117 |
+
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
|
1118 |
+
lora_grouped_dict[attn_processor_key][sub_key] = value
|
1119 |
+
|
1120 |
+
for key, value_dict in lora_grouped_dict.items():
|
1121 |
+
rank = value_dict["to_k_lora.down.weight"].shape[0]
|
1122 |
+
cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
|
1123 |
+
hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
|
1124 |
+
|
1125 |
+
attn_processor_class = (
|
1126 |
+
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
1127 |
+
)
|
1128 |
+
attn_processors[key] = attn_processor_class(
|
1129 |
+
hidden_size=hidden_size,
|
1130 |
+
cross_attention_dim=cross_attention_dim,
|
1131 |
+
rank=rank,
|
1132 |
+
network_alpha=network_alpha,
|
1133 |
+
)
|
1134 |
+
attn_processors[key].load_state_dict(value_dict)
|
1135 |
+
|
1136 |
+
else:
|
1137 |
+
raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")
|
1138 |
+
|
1139 |
+
# set correct dtype & device
|
1140 |
+
attn_processors = {
|
1141 |
+
k: v.to(device=self.device, dtype=self.text_encoder.dtype) for k, v in attn_processors.items()
|
1142 |
+
}
|
1143 |
+
return attn_processors
|
1144 |
+
|
1145 |
+
@classmethod
|
1146 |
+
def save_lora_weights(
|
1147 |
+
self,
|
1148 |
+
save_directory: Union[str, os.PathLike],
|
1149 |
+
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
1150 |
+
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
|
1151 |
+
is_main_process: bool = True,
|
1152 |
+
weight_name: str = None,
|
1153 |
+
save_function: Callable = None,
|
1154 |
+
safe_serialization: bool = False,
|
1155 |
+
):
|
1156 |
+
r"""
|
1157 |
+
Save the LoRA parameters corresponding to the UNet and text encoder.
|
1158 |
+
|
1159 |
+
Arguments:
|
1160 |
+
save_directory (`str` or `os.PathLike`):
|
1161 |
+
Directory to save LoRA parameters to. Will be created if it doesn't exist.
|
1162 |
+
unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
1163 |
+
State dict of the LoRA layers corresponding to the UNet.
|
1164 |
+
text_encoder_lora_layers (`Dict[str, torch.nn.Module] or `Dict[str, torch.Tensor]`):
|
1165 |
+
State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
|
1166 |
+
encoder LoRA state dict because it comes 🤗 Transformers.
|
1167 |
+
is_main_process (`bool`, *optional*, defaults to `True`):
|
1168 |
+
Whether the process calling this is the main process or not. Useful during distributed training and you
|
1169 |
+
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
1170 |
+
process to avoid race conditions.
|
1171 |
+
save_function (`Callable`):
|
1172 |
+
The function to use to save the state dictionary. Useful during distributed training when you need to
|
1173 |
+
replace `torch.save` with another method. Can be configured with the environment variable
|
1174 |
+
`DIFFUSERS_SAVE_MODE`.
|
1175 |
+
"""
|
1176 |
+
if os.path.isfile(save_directory):
|
1177 |
+
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
1178 |
+
return
|
1179 |
+
|
1180 |
+
if save_function is None:
|
1181 |
+
if safe_serialization:
|
1182 |
+
|
1183 |
+
def save_function(weights, filename):
|
1184 |
+
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
|
1185 |
+
|
1186 |
+
else:
|
1187 |
+
save_function = torch.save
|
1188 |
+
|
1189 |
+
os.makedirs(save_directory, exist_ok=True)
|
1190 |
+
|
1191 |
+
# Create a flat dictionary.
|
1192 |
+
state_dict = {}
|
1193 |
+
if unet_lora_layers is not None:
|
1194 |
+
weights = (
|
1195 |
+
unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
|
1196 |
+
)
|
1197 |
+
|
1198 |
+
unet_lora_state_dict = {f"{self.unet_name}.{module_name}": param for module_name, param in weights.items()}
|
1199 |
+
state_dict.update(unet_lora_state_dict)
|
1200 |
+
|
1201 |
+
if text_encoder_lora_layers is not None:
|
1202 |
+
weights = (
|
1203 |
+
text_encoder_lora_layers.state_dict()
|
1204 |
+
if isinstance(text_encoder_lora_layers, torch.nn.Module)
|
1205 |
+
else text_encoder_lora_layers
|
1206 |
+
)
|
1207 |
+
|
1208 |
+
text_encoder_lora_state_dict = {
|
1209 |
+
f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
|
1210 |
+
}
|
1211 |
+
state_dict.update(text_encoder_lora_state_dict)
|
1212 |
+
|
1213 |
+
# Save the model
|
1214 |
+
if weight_name is None:
|
1215 |
+
if safe_serialization:
|
1216 |
+
weight_name = LORA_WEIGHT_NAME_SAFE
|
1217 |
+
else:
|
1218 |
+
weight_name = LORA_WEIGHT_NAME
|
1219 |
+
|
1220 |
+
save_function(state_dict, os.path.join(save_directory, weight_name))
|
1221 |
+
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
|
1222 |
+
|
1223 |
+
def _convert_kohya_lora_to_diffusers(self, state_dict):
|
1224 |
+
unet_state_dict = {}
|
1225 |
+
te_state_dict = {}
|
1226 |
+
network_alpha = None
|
1227 |
+
|
1228 |
+
for key, value in state_dict.items():
|
1229 |
+
if "lora_down" in key:
|
1230 |
+
lora_name = key.split(".")[0]
|
1231 |
+
lora_name_up = lora_name + ".lora_up.weight"
|
1232 |
+
lora_name_alpha = lora_name + ".alpha"
|
1233 |
+
if lora_name_alpha in state_dict:
|
1234 |
+
alpha = state_dict[lora_name_alpha].item()
|
1235 |
+
if network_alpha is None:
|
1236 |
+
network_alpha = alpha
|
1237 |
+
elif network_alpha != alpha:
|
1238 |
+
raise ValueError("Network alpha is not consistent")
|
1239 |
+
|
1240 |
+
if lora_name.startswith("lora_unet_"):
|
1241 |
+
diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
|
1242 |
+
diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
|
1243 |
+
diffusers_name = diffusers_name.replace("mid.block", "mid_block")
|
1244 |
+
diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
|
1245 |
+
diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
|
1246 |
+
diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
|
1247 |
+
diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
|
1248 |
+
diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
|
1249 |
+
diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
|
1250 |
+
if "transformer_blocks" in diffusers_name:
|
1251 |
+
if "attn1" in diffusers_name or "attn2" in diffusers_name:
|
1252 |
+
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
|
1253 |
+
diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
|
1254 |
+
unet_state_dict[diffusers_name] = value
|
1255 |
+
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
|
1256 |
+
elif lora_name.startswith("lora_te_"):
|
1257 |
+
diffusers_name = key.replace("lora_te_", "").replace("_", ".")
|
1258 |
+
diffusers_name = diffusers_name.replace("text.model", "text_model")
|
1259 |
+
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
|
1260 |
+
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
|
1261 |
+
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
|
1262 |
+
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
|
1263 |
+
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
|
1264 |
+
if "self_attn" in diffusers_name:
|
1265 |
+
te_state_dict[diffusers_name] = value
|
1266 |
+
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
|
1267 |
+
|
1268 |
+
unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
|
1269 |
+
te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
|
1270 |
+
new_state_dict = {**unet_state_dict, **te_state_dict}
|
1271 |
+
return new_state_dict, network_alpha
|
1272 |
+
|
1273 |
+
|
1274 |
+
class FromCkptMixin:
|
1275 |
+
"""
|
1276 |
+
Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
|
1277 |
+
"""
|
1278 |
+
|
1279 |
+
@classmethod
|
1280 |
+
def from_ckpt(cls, pretrained_model_link_or_path, **kwargs):
|
1281 |
+
r"""
|
1282 |
+
Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` format. The pipeline
|
1283 |
+
is set in evaluation mode (`model.eval()`) by default.
|
1284 |
+
|
1285 |
+
Parameters:
|
1286 |
+
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
1287 |
+
Can be either:
|
1288 |
+
- A link to the `.ckpt` file (for example
|
1289 |
+
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
1290 |
+
- A path to a *file* containing all pipeline weights.
|
1291 |
+
torch_dtype (`str` or `torch.dtype`, *optional*):
|
1292 |
+
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
1293 |
+
dtype is automatically derived from the model's weights.
|
1294 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
1295 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
1296 |
+
cached versions if they exist.
|
1297 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
1298 |
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
1299 |
+
is not used.
|
1300 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
1301 |
+
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
1302 |
+
incompletely downloaded files are deleted.
|
1303 |
+
proxies (`Dict[str, str]`, *optional*):
|
1304 |
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
1305 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
1306 |
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
1307 |
+
Whether to only load local model weights and configuration files or not. If set to True, the model
|
1308 |
+
won't be downloaded from the Hub.
|
1309 |
+
use_auth_token (`str` or *bool*, *optional*):
|
1310 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
1311 |
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
1312 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
1313 |
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
1314 |
+
allowed by Git.
|
1315 |
+
use_safetensors (`bool`, *optional*, defaults to `None`):
|
1316 |
+
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
1317 |
+
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
1318 |
+
weights. If set to `False`, safetensors weights are not loaded.
|
1319 |
+
extract_ema (`bool`, *optional*, defaults to `False`):
|
1320 |
+
Whether to extract the EMA weights or not. Pass `True` to extract the EMA weights which usually yield
|
1321 |
+
higher quality images for inference. Non-EMA weights are usually better to continue finetuning.
|
1322 |
+
upcast_attention (`bool`, *optional*, defaults to `None`):
|
1323 |
+
Whether the attention computation should always be upcasted.
|
1324 |
+
image_size (`int`, *optional*, defaults to 512):
|
1325 |
+
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
1326 |
+
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
1327 |
+
prediction_type (`str`, *optional*):
|
1328 |
+
The prediction type the model was trained on. Use `'epsilon'` for all Stable Diffusion v1 models and
|
1329 |
+
the Stable Diffusion v2 base model. Use `'v_prediction'` for Stable Diffusion v2.
|
1330 |
+
num_in_channels (`int`, *optional*, defaults to `None`):
|
1331 |
+
The number of input channels. If `None`, it will be automatically inferred.
|
1332 |
+
scheduler_type (`str`, *optional*, defaults to `"pndm"`):
|
1333 |
+
Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm",
|
1334 |
+
"ddim"]`.
|
1335 |
+
load_safety_checker (`bool`, *optional*, defaults to `True`):
|
1336 |
+
Whether to load the safety checker or not.
|
1337 |
+
text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):
|
1338 |
+
An instance of
|
1339 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) to use,
|
1340 |
+
specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)
|
1341 |
+
variant. If this parameter is `None`, the function will load a new instance of [CLIP] by itself, if
|
1342 |
+
needed.
|
1343 |
+
tokenizer (`CLIPTokenizer`, *optional*, defaults to `None`):
|
1344 |
+
An instance of
|
1345 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
|
1346 |
+
to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by
|
1347 |
+
itself, if needed.
|
1348 |
+
kwargs (remaining dictionary of keyword arguments, *optional*):
|
1349 |
+
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
1350 |
+
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
1351 |
+
method. See example below for more information.
|
1352 |
+
|
1353 |
+
Examples:
|
1354 |
+
|
1355 |
+
```py
|
1356 |
+
>>> from diffusers import StableDiffusionPipeline
|
1357 |
+
|
1358 |
+
>>> # Download pipeline from huggingface.co and cache.
|
1359 |
+
>>> pipeline = StableDiffusionPipeline.from_ckpt(
|
1360 |
+
... "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
|
1361 |
+
... )
|
1362 |
+
|
1363 |
+
>>> # Download pipeline from local file
|
1364 |
+
>>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
|
1365 |
+
>>> pipeline = StableDiffusionPipeline.from_ckpt("./v1-5-pruned-emaonly")
|
1366 |
+
|
1367 |
+
>>> # Enable float16 and move to GPU
|
1368 |
+
>>> pipeline = StableDiffusionPipeline.from_ckpt(
|
1369 |
+
... "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
|
1370 |
+
... torch_dtype=torch.float16,
|
1371 |
+
... )
|
1372 |
+
>>> pipeline.to("cuda")
|
1373 |
+
```
|
1374 |
+
"""
|
1375 |
+
# import here to avoid circular dependency
|
1376 |
+
from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
|
1377 |
+
|
1378 |
+
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
1379 |
+
resume_download = kwargs.pop("resume_download", False)
|
1380 |
+
force_download = kwargs.pop("force_download", False)
|
1381 |
+
proxies = kwargs.pop("proxies", None)
|
1382 |
+
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
1383 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
1384 |
+
revision = kwargs.pop("revision", None)
|
1385 |
+
extract_ema = kwargs.pop("extract_ema", False)
|
1386 |
+
image_size = kwargs.pop("image_size", 512)
|
1387 |
+
scheduler_type = kwargs.pop("scheduler_type", "pndm")
|
1388 |
+
num_in_channels = kwargs.pop("num_in_channels", None)
|
1389 |
+
upcast_attention = kwargs.pop("upcast_attention", None)
|
1390 |
+
load_safety_checker = kwargs.pop("load_safety_checker", True)
|
1391 |
+
prediction_type = kwargs.pop("prediction_type", None)
|
1392 |
+
text_encoder = kwargs.pop("text_encoder", None)
|
1393 |
+
tokenizer = kwargs.pop("tokenizer", None)
|
1394 |
+
|
1395 |
+
torch_dtype = kwargs.pop("torch_dtype", None)
|
1396 |
+
|
1397 |
+
use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
|
1398 |
+
|
1399 |
+
pipeline_name = cls.__name__
|
1400 |
+
file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
|
1401 |
+
from_safetensors = file_extension == "safetensors"
|
1402 |
+
|
1403 |
+
if from_safetensors and use_safetensors is False:
|
1404 |
+
raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
|
1405 |
+
|
1406 |
+
# TODO: For now we only support stable diffusion
|
1407 |
+
stable_unclip = None
|
1408 |
+
model_type = None
|
1409 |
+
controlnet = False
|
1410 |
+
|
1411 |
+
if pipeline_name == "StableDiffusionControlNetPipeline":
|
1412 |
+
# Model type will be inferred from the checkpoint.
|
1413 |
+
controlnet = True
|
1414 |
+
elif "StableDiffusion" in pipeline_name:
|
1415 |
+
# Model type will be inferred from the checkpoint.
|
1416 |
+
pass
|
1417 |
+
elif pipeline_name == "StableUnCLIPPipeline":
|
1418 |
+
model_type = "FrozenOpenCLIPEmbedder"
|
1419 |
+
stable_unclip = "txt2img"
|
1420 |
+
elif pipeline_name == "StableUnCLIPImg2ImgPipeline":
|
1421 |
+
model_type = "FrozenOpenCLIPEmbedder"
|
1422 |
+
stable_unclip = "img2img"
|
1423 |
+
elif pipeline_name == "PaintByExamplePipeline":
|
1424 |
+
model_type = "PaintByExample"
|
1425 |
+
elif pipeline_name == "LDMTextToImagePipeline":
|
1426 |
+
model_type = "LDMTextToImage"
|
1427 |
+
else:
|
1428 |
+
raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
|
1429 |
+
|
1430 |
+
# remove huggingface url
|
1431 |
+
for prefix in ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]:
|
1432 |
+
if pretrained_model_link_or_path.startswith(prefix):
|
1433 |
+
pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
|
1434 |
+
|
1435 |
+
# Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
|
1436 |
+
ckpt_path = Path(pretrained_model_link_or_path)
|
1437 |
+
if not ckpt_path.is_file():
|
1438 |
+
# get repo_id and (potentially nested) file path of ckpt in repo
|
1439 |
+
repo_id = "/".join(ckpt_path.parts[:2])
|
1440 |
+
file_path = "/".join(ckpt_path.parts[2:])
|
1441 |
+
|
1442 |
+
if file_path.startswith("blob/"):
|
1443 |
+
file_path = file_path[len("blob/") :]
|
1444 |
+
|
1445 |
+
if file_path.startswith("main/"):
|
1446 |
+
file_path = file_path[len("main/") :]
|
1447 |
+
|
1448 |
+
pretrained_model_link_or_path = hf_hub_download(
|
1449 |
+
repo_id,
|
1450 |
+
filename=file_path,
|
1451 |
+
cache_dir=cache_dir,
|
1452 |
+
resume_download=resume_download,
|
1453 |
+
proxies=proxies,
|
1454 |
+
local_files_only=local_files_only,
|
1455 |
+
use_auth_token=use_auth_token,
|
1456 |
+
revision=revision,
|
1457 |
+
force_download=force_download,
|
1458 |
+
)
|
1459 |
+
|
1460 |
+
pipe = download_from_original_stable_diffusion_ckpt(
|
1461 |
+
pretrained_model_link_or_path,
|
1462 |
+
pipeline_class=cls,
|
1463 |
+
model_type=model_type,
|
1464 |
+
stable_unclip=stable_unclip,
|
1465 |
+
controlnet=controlnet,
|
1466 |
+
from_safetensors=from_safetensors,
|
1467 |
+
extract_ema=extract_ema,
|
1468 |
+
image_size=image_size,
|
1469 |
+
scheduler_type=scheduler_type,
|
1470 |
+
num_in_channels=num_in_channels,
|
1471 |
+
upcast_attention=upcast_attention,
|
1472 |
+
load_safety_checker=load_safety_checker,
|
1473 |
+
prediction_type=prediction_type,
|
1474 |
+
text_encoder=text_encoder,
|
1475 |
+
tokenizer=tokenizer,
|
1476 |
+
)
|
1477 |
+
|
1478 |
+
if torch_dtype is not None:
|
1479 |
+
pipe.to(torch_dtype=torch_dtype)
|
1480 |
+
|
1481 |
+
return pipe
|
diffusers/models/modeling_utils.py
ADDED
@@ -0,0 +1,978 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
|
17 |
+
import inspect
|
18 |
+
import itertools
|
19 |
+
import os
|
20 |
+
import re
|
21 |
+
from functools import partial
|
22 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
23 |
+
|
24 |
+
import torch
|
25 |
+
from torch import Tensor, device, nn
|
26 |
+
|
27 |
+
from ..utils.constants import (
|
28 |
+
CONFIG_NAME,
|
29 |
+
DIFFUSERS_CACHE,
|
30 |
+
FLAX_WEIGHTS_NAME,
|
31 |
+
SAFETENSORS_WEIGHTS_NAME,
|
32 |
+
WEIGHTS_NAME
|
33 |
+
)
|
34 |
+
from ..utils.hub_utils import (
|
35 |
+
HF_HUB_OFFLINE,
|
36 |
+
_add_variant,
|
37 |
+
_get_model_file
|
38 |
+
)
|
39 |
+
from ..utils.deprecation_utils import deprecate
|
40 |
+
from ..utils.import_utils import (
|
41 |
+
is_accelerate_available,
|
42 |
+
is_safetensors_available,
|
43 |
+
is_torch_version
|
44 |
+
)
|
45 |
+
from ..utils.logging import get_logger
|
46 |
+
|
47 |
+
logger = get_logger(__name__)
|
48 |
+
|
49 |
+
|
50 |
+
if is_torch_version(">=", "1.9.0"):
|
51 |
+
_LOW_CPU_MEM_USAGE_DEFAULT = True
|
52 |
+
else:
|
53 |
+
_LOW_CPU_MEM_USAGE_DEFAULT = False
|
54 |
+
|
55 |
+
|
56 |
+
if is_accelerate_available():
|
57 |
+
import accelerate
|
58 |
+
from accelerate.utils import set_module_tensor_to_device
|
59 |
+
from accelerate.utils.versions import is_torch_version
|
60 |
+
|
61 |
+
if is_safetensors_available():
|
62 |
+
import safetensors
|
63 |
+
|
64 |
+
|
65 |
+
def get_parameter_device(parameter: torch.nn.Module):
|
66 |
+
try:
|
67 |
+
parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
|
68 |
+
return next(parameters_and_buffers).device
|
69 |
+
except StopIteration:
|
70 |
+
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
71 |
+
|
72 |
+
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
|
73 |
+
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
74 |
+
return tuples
|
75 |
+
|
76 |
+
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
77 |
+
first_tuple = next(gen)
|
78 |
+
return first_tuple[1].device
|
79 |
+
|
80 |
+
|
81 |
+
def get_parameter_dtype(parameter: torch.nn.Module):
|
82 |
+
try:
|
83 |
+
params = tuple(parameter.parameters())
|
84 |
+
if len(params) > 0:
|
85 |
+
return params[0].dtype
|
86 |
+
|
87 |
+
buffers = tuple(parameter.buffers())
|
88 |
+
if len(buffers) > 0:
|
89 |
+
return buffers[0].dtype
|
90 |
+
|
91 |
+
except StopIteration:
|
92 |
+
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
93 |
+
|
94 |
+
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
|
95 |
+
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
96 |
+
return tuples
|
97 |
+
|
98 |
+
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
99 |
+
first_tuple = next(gen)
|
100 |
+
return first_tuple[1].dtype
|
101 |
+
|
102 |
+
|
103 |
+
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
|
104 |
+
"""
|
105 |
+
Reads a checkpoint file, returning properly formatted errors if they arise.
|
106 |
+
"""
|
107 |
+
try:
|
108 |
+
if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
|
109 |
+
return torch.load(checkpoint_file, map_location="cpu")
|
110 |
+
else:
|
111 |
+
return safetensors.torch.load_file(checkpoint_file, device="cpu")
|
112 |
+
except Exception as e:
|
113 |
+
try:
|
114 |
+
with open(checkpoint_file) as f:
|
115 |
+
if f.read().startswith("version"):
|
116 |
+
raise OSError(
|
117 |
+
"You seem to have cloned a repository without having git-lfs installed. Please install "
|
118 |
+
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
|
119 |
+
"you cloned."
|
120 |
+
)
|
121 |
+
else:
|
122 |
+
raise ValueError(
|
123 |
+
f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
|
124 |
+
"model. Make sure you have saved the model properly."
|
125 |
+
) from e
|
126 |
+
except (UnicodeDecodeError, ValueError):
|
127 |
+
raise OSError(
|
128 |
+
f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
|
129 |
+
f"at '{checkpoint_file}'. "
|
130 |
+
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
|
131 |
+
)
|
132 |
+
|
133 |
+
|
134 |
+
def _load_state_dict_into_model(model_to_load, state_dict):
|
135 |
+
# Convert old format to new format if needed from a PyTorch state_dict
|
136 |
+
# copy state_dict so _load_from_state_dict can modify it
|
137 |
+
state_dict = state_dict.copy()
|
138 |
+
error_msgs = []
|
139 |
+
|
140 |
+
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
141 |
+
# so we need to apply the function recursively.
|
142 |
+
def load(module: torch.nn.Module, prefix=""):
|
143 |
+
args = (state_dict, prefix, {}, True, [], [], error_msgs)
|
144 |
+
module._load_from_state_dict(*args)
|
145 |
+
|
146 |
+
for name, child in module._modules.items():
|
147 |
+
if child is not None:
|
148 |
+
load(child, prefix + name + ".")
|
149 |
+
|
150 |
+
load(model_to_load)
|
151 |
+
|
152 |
+
return error_msgs
|
153 |
+
|
154 |
+
|
155 |
+
class ModelMixin(torch.nn.Module):
|
156 |
+
r"""
|
157 |
+
Base class for all models.
|
158 |
+
|
159 |
+
[`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
|
160 |
+
and saving models.
|
161 |
+
|
162 |
+
- **config_name** ([`str`]) -- A filename under which the model should be stored when calling
|
163 |
+
[`~models.ModelMixin.save_pretrained`].
|
164 |
+
"""
|
165 |
+
config_name = CONFIG_NAME
|
166 |
+
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
|
167 |
+
_supports_gradient_checkpointing = False
|
168 |
+
_keys_to_ignore_on_load_unexpected = None
|
169 |
+
|
170 |
+
def __init__(self):
|
171 |
+
super().__init__()
|
172 |
+
|
173 |
+
def __getattr__(self, name: str) -> Any:
|
174 |
+
"""The only reason we overwrite `getattr` here is to gracefully deprecate accessing
|
175 |
+
config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
|
176 |
+
__getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__':
|
177 |
+
https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
|
178 |
+
"""
|
179 |
+
|
180 |
+
is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
|
181 |
+
is_attribute = name in self.__dict__
|
182 |
+
|
183 |
+
if is_in_config and not is_attribute:
|
184 |
+
deprecation_message = (
|
185 |
+
f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. "
|
186 |
+
f"Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
|
187 |
+
)
|
188 |
+
deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
|
189 |
+
return self._internal_dict[name]
|
190 |
+
|
191 |
+
# call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
|
192 |
+
return super().__getattr__(name)
|
193 |
+
|
194 |
+
@property
|
195 |
+
def is_gradient_checkpointing(self) -> bool:
|
196 |
+
"""
|
197 |
+
Whether gradient checkpointing is activated for this model or not.
|
198 |
+
|
199 |
+
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
|
200 |
+
activations".
|
201 |
+
"""
|
202 |
+
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
|
203 |
+
|
204 |
+
def enable_gradient_checkpointing(self):
|
205 |
+
"""
|
206 |
+
Activates gradient checkpointing for the current model.
|
207 |
+
|
208 |
+
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
|
209 |
+
activations".
|
210 |
+
"""
|
211 |
+
if not self._supports_gradient_checkpointing:
|
212 |
+
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
|
213 |
+
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
214 |
+
|
215 |
+
def disable_gradient_checkpointing(self):
|
216 |
+
"""
|
217 |
+
Deactivates gradient checkpointing for the current model.
|
218 |
+
|
219 |
+
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
|
220 |
+
activations".
|
221 |
+
"""
|
222 |
+
if self._supports_gradient_checkpointing:
|
223 |
+
self.apply(partial(self._set_gradient_checkpointing, value=False))
|
224 |
+
|
225 |
+
def set_use_memory_efficient_attention_xformers(
|
226 |
+
self, valid: bool, attention_op: Optional[Callable] = None
|
227 |
+
) -> None:
|
228 |
+
# Recursively walk through all the children.
|
229 |
+
# Any children which exposes the set_use_memory_efficient_attention_xformers method
|
230 |
+
# gets the message
|
231 |
+
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
232 |
+
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
|
233 |
+
module.set_use_memory_efficient_attention_xformers(valid, attention_op)
|
234 |
+
|
235 |
+
for child in module.children():
|
236 |
+
fn_recursive_set_mem_eff(child)
|
237 |
+
|
238 |
+
for module in self.children():
|
239 |
+
if isinstance(module, torch.nn.Module):
|
240 |
+
fn_recursive_set_mem_eff(module)
|
241 |
+
|
242 |
+
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
|
243 |
+
r"""
|
244 |
+
Enable memory efficient attention as implemented in xformers.
|
245 |
+
|
246 |
+
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
|
247 |
+
time. Speed up at training time is not guaranteed.
|
248 |
+
|
249 |
+
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
|
250 |
+
is used.
|
251 |
+
|
252 |
+
Parameters:
|
253 |
+
attention_op (`Callable`, *optional*):
|
254 |
+
Override the default `None` operator for use as `op` argument to the
|
255 |
+
[`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
|
256 |
+
function of xFormers.
|
257 |
+
|
258 |
+
Examples:
|
259 |
+
|
260 |
+
```py
|
261 |
+
>>> import torch
|
262 |
+
>>> from diffusers import UNet2DConditionModel
|
263 |
+
>>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
|
264 |
+
|
265 |
+
>>> model = UNet2DConditionModel.from_pretrained(
|
266 |
+
... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
|
267 |
+
... )
|
268 |
+
>>> model = model.to("cuda")
|
269 |
+
>>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
|
270 |
+
```
|
271 |
+
"""
|
272 |
+
self.set_use_memory_efficient_attention_xformers(True, attention_op)
|
273 |
+
|
274 |
+
def disable_xformers_memory_efficient_attention(self):
|
275 |
+
r"""
|
276 |
+
Disable memory efficient attention as implemented in xformers.
|
277 |
+
"""
|
278 |
+
self.set_use_memory_efficient_attention_xformers(False)
|
279 |
+
|
280 |
+
def save_pretrained(
|
281 |
+
self,
|
282 |
+
save_directory: Union[str, os.PathLike],
|
283 |
+
is_main_process: bool = True,
|
284 |
+
save_function: Callable = None,
|
285 |
+
safe_serialization: bool = False,
|
286 |
+
variant: Optional[str] = None,
|
287 |
+
):
|
288 |
+
"""
|
289 |
+
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
290 |
+
`[`~models.ModelMixin.from_pretrained`]` class method.
|
291 |
+
|
292 |
+
Arguments:
|
293 |
+
save_directory (`str` or `os.PathLike`):
|
294 |
+
Directory to which to save. Will be created if it doesn't exist.
|
295 |
+
is_main_process (`bool`, *optional*, defaults to `True`):
|
296 |
+
Whether the process calling this is the main process or not. Useful when in distributed training like
|
297 |
+
TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
|
298 |
+
the main process to avoid race conditions.
|
299 |
+
save_function (`Callable`):
|
300 |
+
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
|
301 |
+
need to replace `torch.save` by another method. Can be configured with the environment variable
|
302 |
+
`DIFFUSERS_SAVE_MODE`.
|
303 |
+
safe_serialization (`bool`, *optional*, defaults to `False`):
|
304 |
+
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
|
305 |
+
variant (`str`, *optional*):
|
306 |
+
If specified, weights are saved in the format pytorch_model.<variant>.bin.
|
307 |
+
"""
|
308 |
+
if safe_serialization and not is_safetensors_available():
|
309 |
+
raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
|
310 |
+
|
311 |
+
if os.path.isfile(save_directory):
|
312 |
+
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
313 |
+
return
|
314 |
+
|
315 |
+
os.makedirs(save_directory, exist_ok=True)
|
316 |
+
|
317 |
+
model_to_save = self
|
318 |
+
|
319 |
+
# Attach architecture to the config
|
320 |
+
# Save the config
|
321 |
+
if is_main_process:
|
322 |
+
model_to_save.save_config(save_directory)
|
323 |
+
|
324 |
+
# Save the model
|
325 |
+
state_dict = model_to_save.state_dict()
|
326 |
+
|
327 |
+
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
|
328 |
+
weights_name = _add_variant(weights_name, variant)
|
329 |
+
|
330 |
+
# Save the model
|
331 |
+
if safe_serialization:
|
332 |
+
safetensors.torch.save_file(
|
333 |
+
state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"}
|
334 |
+
)
|
335 |
+
else:
|
336 |
+
torch.save(state_dict, os.path.join(save_directory, weights_name))
|
337 |
+
|
338 |
+
logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
|
339 |
+
|
340 |
+
@classmethod
|
341 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
342 |
+
r"""
|
343 |
+
Instantiate a pretrained pytorch model from a pre-trained model configuration.
|
344 |
+
|
345 |
+
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
|
346 |
+
the model, you should first set it back in training mode with `model.train()`.
|
347 |
+
|
348 |
+
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
|
349 |
+
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
|
350 |
+
task.
|
351 |
+
|
352 |
+
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
|
353 |
+
weights are discarded.
|
354 |
+
|
355 |
+
Parameters:
|
356 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
357 |
+
Can be either:
|
358 |
+
|
359 |
+
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
|
360 |
+
Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
|
361 |
+
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
|
362 |
+
`./my_model_directory/`.
|
363 |
+
|
364 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
365 |
+
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
366 |
+
standard cache should not be used.
|
367 |
+
torch_dtype (`str` or `torch.dtype`, *optional*):
|
368 |
+
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
|
369 |
+
will be automatically derived from the model's weights.
|
370 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
371 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
372 |
+
cached versions if they exist.
|
373 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
374 |
+
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
375 |
+
file exists.
|
376 |
+
proxies (`Dict[str, str]`, *optional*):
|
377 |
+
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
378 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
379 |
+
output_loading_info(`bool`, *optional*, defaults to `False`):
|
380 |
+
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
381 |
+
local_files_only(`bool`, *optional*, defaults to `False`):
|
382 |
+
Whether or not to only look at local files (i.e., do not try to download the model).
|
383 |
+
use_auth_token (`str` or *bool*, *optional*):
|
384 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
385 |
+
when running `diffusers-cli login` (stored in `~/.huggingface`).
|
386 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
387 |
+
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
388 |
+
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
389 |
+
identifier allowed by git.
|
390 |
+
from_flax (`bool`, *optional*, defaults to `False`):
|
391 |
+
Load the model weights from a Flax checkpoint save file.
|
392 |
+
subfolder (`str`, *optional*, defaults to `""`):
|
393 |
+
In case the relevant files are located inside a subfolder of the model repo (either remote in
|
394 |
+
huggingface.co or downloaded locally), you can specify the folder name here.
|
395 |
+
|
396 |
+
mirror (`str`, *optional*):
|
397 |
+
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
398 |
+
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
399 |
+
Please refer to the mirror site for more information.
|
400 |
+
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
|
401 |
+
A map that specifies where each submodule should go. It doesn't need to be refined to each
|
402 |
+
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
|
403 |
+
same device.
|
404 |
+
|
405 |
+
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
|
406 |
+
more information about each option see [designing a device
|
407 |
+
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
|
408 |
+
max_memory (`Dict`, *optional*):
|
409 |
+
A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
|
410 |
+
GPU and the available CPU RAM if unset.
|
411 |
+
offload_folder (`str` or `os.PathLike`, *optional*):
|
412 |
+
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
|
413 |
+
offload_state_dict (`bool`, *optional*):
|
414 |
+
If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU
|
415 |
+
RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to
|
416 |
+
`True` when there is some disk offload.
|
417 |
+
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
418 |
+
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
|
419 |
+
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
|
420 |
+
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
|
421 |
+
setting this argument to `True` will raise an error.
|
422 |
+
variant (`str`, *optional*):
|
423 |
+
If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
|
424 |
+
ignored when using `from_flax`.
|
425 |
+
use_safetensors (`bool`, *optional*, defaults to `None`):
|
426 |
+
If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
|
427 |
+
`safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
|
428 |
+
`safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
|
429 |
+
|
430 |
+
<Tip>
|
431 |
+
|
432 |
+
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
|
433 |
+
models](https://huggingface.co/docs/hub/models-gated#gated-models).
|
434 |
+
|
435 |
+
</Tip>
|
436 |
+
|
437 |
+
<Tip>
|
438 |
+
|
439 |
+
Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
|
440 |
+
this method in a firewalled environment.
|
441 |
+
|
442 |
+
</Tip>
|
443 |
+
|
444 |
+
"""
|
445 |
+
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
446 |
+
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
447 |
+
force_download = kwargs.pop("force_download", False)
|
448 |
+
from_flax = kwargs.pop("from_flax", False)
|
449 |
+
resume_download = kwargs.pop("resume_download", False)
|
450 |
+
proxies = kwargs.pop("proxies", None)
|
451 |
+
output_loading_info = kwargs.pop("output_loading_info", False)
|
452 |
+
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
|
453 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
454 |
+
revision = kwargs.pop("revision", None)
|
455 |
+
torch_dtype = kwargs.pop("torch_dtype", None)
|
456 |
+
subfolder = kwargs.pop("subfolder", None)
|
457 |
+
device_map = kwargs.pop("device_map", None)
|
458 |
+
max_memory = kwargs.pop("max_memory", None)
|
459 |
+
offload_folder = kwargs.pop("offload_folder", None)
|
460 |
+
offload_state_dict = kwargs.pop("offload_state_dict", False)
|
461 |
+
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
462 |
+
variant = kwargs.pop("variant", None)
|
463 |
+
use_safetensors = kwargs.pop("use_safetensors", None)
|
464 |
+
|
465 |
+
if use_safetensors and not is_safetensors_available():
|
466 |
+
raise ValueError(
|
467 |
+
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
|
468 |
+
)
|
469 |
+
|
470 |
+
allow_pickle = False
|
471 |
+
if use_safetensors is None:
|
472 |
+
use_safetensors = is_safetensors_available()
|
473 |
+
allow_pickle = True
|
474 |
+
|
475 |
+
if low_cpu_mem_usage and not is_accelerate_available():
|
476 |
+
low_cpu_mem_usage = False
|
477 |
+
logger.warning(
|
478 |
+
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
479 |
+
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
480 |
+
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
481 |
+
" install accelerate\n```\n."
|
482 |
+
)
|
483 |
+
|
484 |
+
if device_map is not None and not is_accelerate_available():
|
485 |
+
raise NotImplementedError(
|
486 |
+
"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
|
487 |
+
" `device_map=None`. You can install accelerate with `pip install accelerate`."
|
488 |
+
)
|
489 |
+
|
490 |
+
# Check if we can handle device_map and dispatching the weights
|
491 |
+
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
492 |
+
raise NotImplementedError(
|
493 |
+
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
494 |
+
" `device_map=None`."
|
495 |
+
)
|
496 |
+
|
497 |
+
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
498 |
+
raise NotImplementedError(
|
499 |
+
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
500 |
+
" `low_cpu_mem_usage=False`."
|
501 |
+
)
|
502 |
+
|
503 |
+
if low_cpu_mem_usage is False and device_map is not None:
|
504 |
+
raise ValueError(
|
505 |
+
f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
|
506 |
+
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
|
507 |
+
)
|
508 |
+
|
509 |
+
# Load config if we don't provide a configuration
|
510 |
+
config_path = pretrained_model_name_or_path
|
511 |
+
|
512 |
+
user_agent = {
|
513 |
+
"diffusers": __version__,
|
514 |
+
"file_type": "model",
|
515 |
+
"framework": "pytorch",
|
516 |
+
}
|
517 |
+
|
518 |
+
# load config
|
519 |
+
config, unused_kwargs, commit_hash = cls.load_config(
|
520 |
+
config_path,
|
521 |
+
cache_dir=cache_dir,
|
522 |
+
return_unused_kwargs=True,
|
523 |
+
return_commit_hash=True,
|
524 |
+
force_download=force_download,
|
525 |
+
resume_download=resume_download,
|
526 |
+
proxies=proxies,
|
527 |
+
local_files_only=local_files_only,
|
528 |
+
use_auth_token=use_auth_token,
|
529 |
+
revision=revision,
|
530 |
+
subfolder=subfolder,
|
531 |
+
device_map=device_map,
|
532 |
+
max_memory=max_memory,
|
533 |
+
offload_folder=offload_folder,
|
534 |
+
offload_state_dict=offload_state_dict,
|
535 |
+
user_agent=user_agent,
|
536 |
+
**kwargs,
|
537 |
+
)
|
538 |
+
|
539 |
+
# load model
|
540 |
+
model_file = None
|
541 |
+
if from_flax:
|
542 |
+
model_file = _get_model_file(
|
543 |
+
pretrained_model_name_or_path,
|
544 |
+
weights_name=FLAX_WEIGHTS_NAME,
|
545 |
+
cache_dir=cache_dir,
|
546 |
+
force_download=force_download,
|
547 |
+
resume_download=resume_download,
|
548 |
+
proxies=proxies,
|
549 |
+
local_files_only=local_files_only,
|
550 |
+
use_auth_token=use_auth_token,
|
551 |
+
revision=revision,
|
552 |
+
subfolder=subfolder,
|
553 |
+
user_agent=user_agent,
|
554 |
+
commit_hash=commit_hash,
|
555 |
+
)
|
556 |
+
model = cls.from_config(config, **unused_kwargs)
|
557 |
+
|
558 |
+
# Convert the weights
|
559 |
+
from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
|
560 |
+
|
561 |
+
model = load_flax_checkpoint_in_pytorch_model(model, model_file)
|
562 |
+
else:
|
563 |
+
if use_safetensors:
|
564 |
+
try:
|
565 |
+
model_file = _get_model_file(
|
566 |
+
pretrained_model_name_or_path,
|
567 |
+
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
|
568 |
+
cache_dir=cache_dir,
|
569 |
+
force_download=force_download,
|
570 |
+
resume_download=resume_download,
|
571 |
+
proxies=proxies,
|
572 |
+
local_files_only=local_files_only,
|
573 |
+
use_auth_token=use_auth_token,
|
574 |
+
revision=revision,
|
575 |
+
subfolder=subfolder,
|
576 |
+
user_agent=user_agent,
|
577 |
+
commit_hash=commit_hash,
|
578 |
+
)
|
579 |
+
except IOError as e:
|
580 |
+
if not allow_pickle:
|
581 |
+
raise e
|
582 |
+
pass
|
583 |
+
if model_file is None:
|
584 |
+
model_file = _get_model_file(
|
585 |
+
pretrained_model_name_or_path,
|
586 |
+
weights_name=_add_variant(WEIGHTS_NAME, variant),
|
587 |
+
cache_dir=cache_dir,
|
588 |
+
force_download=force_download,
|
589 |
+
resume_download=resume_download,
|
590 |
+
proxies=proxies,
|
591 |
+
local_files_only=local_files_only,
|
592 |
+
use_auth_token=use_auth_token,
|
593 |
+
revision=revision,
|
594 |
+
subfolder=subfolder,
|
595 |
+
user_agent=user_agent,
|
596 |
+
commit_hash=commit_hash,
|
597 |
+
)
|
598 |
+
|
599 |
+
if low_cpu_mem_usage:
|
600 |
+
# Instantiate model with empty weights
|
601 |
+
with accelerate.init_empty_weights():
|
602 |
+
model = cls.from_config(config, **unused_kwargs)
|
603 |
+
|
604 |
+
# if device_map is None, load the state dict and move the params from meta device to the cpu
|
605 |
+
if device_map is None:
|
606 |
+
param_device = "cpu"
|
607 |
+
state_dict = load_state_dict(model_file, variant=variant)
|
608 |
+
model._convert_deprecated_attention_blocks(state_dict)
|
609 |
+
# move the params from meta device to cpu
|
610 |
+
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
|
611 |
+
if len(missing_keys) > 0:
|
612 |
+
raise ValueError(
|
613 |
+
f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
|
614 |
+
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
|
615 |
+
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
|
616 |
+
" those weights or else make sure your checkpoint file is correct."
|
617 |
+
)
|
618 |
+
unexpected_keys = []
|
619 |
+
|
620 |
+
empty_state_dict = model.state_dict()
|
621 |
+
for param_name, param in state_dict.items():
|
622 |
+
accepts_dtype = "dtype" in set(
|
623 |
+
inspect.signature(set_module_tensor_to_device).parameters.keys()
|
624 |
+
)
|
625 |
+
|
626 |
+
if param_name not in empty_state_dict:
|
627 |
+
unexpected_keys.append(param_name)
|
628 |
+
continue
|
629 |
+
|
630 |
+
if empty_state_dict[param_name].shape != param.shape:
|
631 |
+
raise ValueError(
|
632 |
+
f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
|
633 |
+
)
|
634 |
+
|
635 |
+
if accepts_dtype:
|
636 |
+
set_module_tensor_to_device(
|
637 |
+
model, param_name, param_device, value=param, dtype=torch_dtype
|
638 |
+
)
|
639 |
+
else:
|
640 |
+
set_module_tensor_to_device(model, param_name, param_device, value=param)
|
641 |
+
|
642 |
+
if cls._keys_to_ignore_on_load_unexpected is not None:
|
643 |
+
for pat in cls._keys_to_ignore_on_load_unexpected:
|
644 |
+
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
645 |
+
|
646 |
+
if len(unexpected_keys) > 0:
|
647 |
+
logger.warn(
|
648 |
+
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
649 |
+
)
|
650 |
+
|
651 |
+
else: # else let accelerate handle loading and dispatching.
|
652 |
+
# Load weights and dispatch according to the device_map
|
653 |
+
# by default the device_map is None and the weights are loaded on the CPU
|
654 |
+
try:
|
655 |
+
accelerate.load_checkpoint_and_dispatch(
|
656 |
+
model,
|
657 |
+
model_file,
|
658 |
+
device_map,
|
659 |
+
max_memory=max_memory,
|
660 |
+
offload_folder=offload_folder,
|
661 |
+
offload_state_dict=offload_state_dict,
|
662 |
+
dtype=torch_dtype,
|
663 |
+
)
|
664 |
+
except AttributeError as e:
|
665 |
+
# When using accelerate loading, we do not have the ability to load the state
|
666 |
+
# dict and rename the weight names manually. Additionally, accelerate skips
|
667 |
+
# torch loading conventions and directly writes into `module.{_buffers, _parameters}`
|
668 |
+
# (which look like they should be private variables?), so we can't use the standard hooks
|
669 |
+
# to rename parameters on load. We need to mimic the original weight names so the correct
|
670 |
+
# attributes are available. After we have loaded the weights, we convert the deprecated
|
671 |
+
# names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
|
672 |
+
# the weights so we don't have to do this again.
|
673 |
+
|
674 |
+
if "'Attention' object has no attribute" in str(e):
|
675 |
+
logger.warn(
|
676 |
+
f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
|
677 |
+
" was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
|
678 |
+
" names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
|
679 |
+
" so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
|
680 |
+
" please also re-upload it or open a PR on the original repository."
|
681 |
+
)
|
682 |
+
model._temp_convert_self_to_deprecated_attention_blocks()
|
683 |
+
accelerate.load_checkpoint_and_dispatch(
|
684 |
+
model,
|
685 |
+
model_file,
|
686 |
+
device_map,
|
687 |
+
max_memory=max_memory,
|
688 |
+
offload_folder=offload_folder,
|
689 |
+
offload_state_dict=offload_state_dict,
|
690 |
+
dtype=torch_dtype,
|
691 |
+
)
|
692 |
+
model._undo_temp_convert_self_to_deprecated_attention_blocks()
|
693 |
+
else:
|
694 |
+
raise e
|
695 |
+
|
696 |
+
loading_info = {
|
697 |
+
"missing_keys": [],
|
698 |
+
"unexpected_keys": [],
|
699 |
+
"mismatched_keys": [],
|
700 |
+
"error_msgs": [],
|
701 |
+
}
|
702 |
+
else:
|
703 |
+
model = cls.from_config(config, **unused_kwargs)
|
704 |
+
|
705 |
+
state_dict = load_state_dict(model_file, variant=variant)
|
706 |
+
model._convert_deprecated_attention_blocks(state_dict)
|
707 |
+
|
708 |
+
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
709 |
+
model,
|
710 |
+
state_dict,
|
711 |
+
model_file,
|
712 |
+
pretrained_model_name_or_path,
|
713 |
+
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
714 |
+
)
|
715 |
+
|
716 |
+
loading_info = {
|
717 |
+
"missing_keys": missing_keys,
|
718 |
+
"unexpected_keys": unexpected_keys,
|
719 |
+
"mismatched_keys": mismatched_keys,
|
720 |
+
"error_msgs": error_msgs,
|
721 |
+
}
|
722 |
+
|
723 |
+
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
724 |
+
raise ValueError(
|
725 |
+
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
|
726 |
+
)
|
727 |
+
elif torch_dtype is not None:
|
728 |
+
model = model.to(torch_dtype)
|
729 |
+
|
730 |
+
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
731 |
+
|
732 |
+
# Set model in evaluation mode to deactivate DropOut modules by default
|
733 |
+
model.eval()
|
734 |
+
if output_loading_info:
|
735 |
+
return model, loading_info
|
736 |
+
|
737 |
+
return model
|
738 |
+
|
739 |
+
@classmethod
|
740 |
+
def _load_pretrained_model(
|
741 |
+
cls,
|
742 |
+
model,
|
743 |
+
state_dict,
|
744 |
+
resolved_archive_file,
|
745 |
+
pretrained_model_name_or_path,
|
746 |
+
ignore_mismatched_sizes=False,
|
747 |
+
):
|
748 |
+
# Retrieve missing & unexpected_keys
|
749 |
+
model_state_dict = model.state_dict()
|
750 |
+
loaded_keys = list(state_dict.keys())
|
751 |
+
|
752 |
+
expected_keys = list(model_state_dict.keys())
|
753 |
+
|
754 |
+
original_loaded_keys = loaded_keys
|
755 |
+
|
756 |
+
missing_keys = list(set(expected_keys) - set(loaded_keys))
|
757 |
+
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
|
758 |
+
|
759 |
+
# Make sure we are able to load base models as well as derived models (with heads)
|
760 |
+
model_to_load = model
|
761 |
+
|
762 |
+
def _find_mismatched_keys(
|
763 |
+
state_dict,
|
764 |
+
model_state_dict,
|
765 |
+
loaded_keys,
|
766 |
+
ignore_mismatched_sizes,
|
767 |
+
):
|
768 |
+
mismatched_keys = []
|
769 |
+
if ignore_mismatched_sizes:
|
770 |
+
for checkpoint_key in loaded_keys:
|
771 |
+
model_key = checkpoint_key
|
772 |
+
|
773 |
+
if (
|
774 |
+
model_key in model_state_dict
|
775 |
+
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
|
776 |
+
):
|
777 |
+
mismatched_keys.append(
|
778 |
+
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
779 |
+
)
|
780 |
+
del state_dict[checkpoint_key]
|
781 |
+
return mismatched_keys
|
782 |
+
|
783 |
+
if state_dict is not None:
|
784 |
+
# Whole checkpoint
|
785 |
+
mismatched_keys = _find_mismatched_keys(
|
786 |
+
state_dict,
|
787 |
+
model_state_dict,
|
788 |
+
original_loaded_keys,
|
789 |
+
ignore_mismatched_sizes,
|
790 |
+
)
|
791 |
+
error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
|
792 |
+
|
793 |
+
if len(error_msgs) > 0:
|
794 |
+
error_msg = "\n\t".join(error_msgs)
|
795 |
+
if "size mismatch" in error_msg:
|
796 |
+
error_msg += (
|
797 |
+
"\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
|
798 |
+
)
|
799 |
+
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
|
800 |
+
|
801 |
+
if len(unexpected_keys) > 0:
|
802 |
+
logger.warning(
|
803 |
+
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
|
804 |
+
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
|
805 |
+
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
|
806 |
+
" or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
|
807 |
+
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
|
808 |
+
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
|
809 |
+
" identical (initializing a BertForSequenceClassification model from a"
|
810 |
+
" BertForSequenceClassification model)."
|
811 |
+
)
|
812 |
+
else:
|
813 |
+
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
|
814 |
+
if len(missing_keys) > 0:
|
815 |
+
logger.warning(
|
816 |
+
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
817 |
+
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
|
818 |
+
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
819 |
+
)
|
820 |
+
elif len(mismatched_keys) == 0:
|
821 |
+
logger.info(
|
822 |
+
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
|
823 |
+
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
|
824 |
+
f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
|
825 |
+
" without further training."
|
826 |
+
)
|
827 |
+
if len(mismatched_keys) > 0:
|
828 |
+
mismatched_warning = "\n".join(
|
829 |
+
[
|
830 |
+
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
|
831 |
+
for key, shape1, shape2 in mismatched_keys
|
832 |
+
]
|
833 |
+
)
|
834 |
+
logger.warning(
|
835 |
+
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
836 |
+
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
|
837 |
+
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
|
838 |
+
" able to use it for predictions and inference."
|
839 |
+
)
|
840 |
+
|
841 |
+
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
|
842 |
+
|
843 |
+
@property
|
844 |
+
def device(self) -> device:
|
845 |
+
"""
|
846 |
+
`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
|
847 |
+
device).
|
848 |
+
"""
|
849 |
+
return get_parameter_device(self)
|
850 |
+
|
851 |
+
@property
|
852 |
+
def dtype(self) -> torch.dtype:
|
853 |
+
"""
|
854 |
+
`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
855 |
+
"""
|
856 |
+
return get_parameter_dtype(self)
|
857 |
+
|
858 |
+
def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
|
859 |
+
"""
|
860 |
+
Get number of (optionally, trainable or non-embeddings) parameters in the module.
|
861 |
+
|
862 |
+
Args:
|
863 |
+
only_trainable (`bool`, *optional*, defaults to `False`):
|
864 |
+
Whether or not to return only the number of trainable parameters
|
865 |
+
|
866 |
+
exclude_embeddings (`bool`, *optional*, defaults to `False`):
|
867 |
+
Whether or not to return only the number of non-embeddings parameters
|
868 |
+
|
869 |
+
Returns:
|
870 |
+
`int`: The number of parameters.
|
871 |
+
"""
|
872 |
+
|
873 |
+
if exclude_embeddings:
|
874 |
+
embedding_param_names = [
|
875 |
+
f"{name}.weight"
|
876 |
+
for name, module_type in self.named_modules()
|
877 |
+
if isinstance(module_type, torch.nn.Embedding)
|
878 |
+
]
|
879 |
+
non_embedding_parameters = [
|
880 |
+
parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
|
881 |
+
]
|
882 |
+
return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
|
883 |
+
else:
|
884 |
+
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
|
885 |
+
|
886 |
+
def _convert_deprecated_attention_blocks(self, state_dict):
|
887 |
+
deprecated_attention_block_paths = []
|
888 |
+
|
889 |
+
def recursive_find_attn_block(name, module):
|
890 |
+
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
891 |
+
deprecated_attention_block_paths.append(name)
|
892 |
+
|
893 |
+
for sub_name, sub_module in module.named_children():
|
894 |
+
sub_name = sub_name if name == "" else f"{name}.{sub_name}"
|
895 |
+
recursive_find_attn_block(sub_name, sub_module)
|
896 |
+
|
897 |
+
recursive_find_attn_block("", self)
|
898 |
+
|
899 |
+
# NOTE: we have to check if the deprecated parameters are in the state dict
|
900 |
+
# because it is possible we are loading from a state dict that was already
|
901 |
+
# converted
|
902 |
+
|
903 |
+
for path in deprecated_attention_block_paths:
|
904 |
+
# group_norm path stays the same
|
905 |
+
|
906 |
+
# query -> to_q
|
907 |
+
if f"{path}.query.weight" in state_dict:
|
908 |
+
state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
|
909 |
+
if f"{path}.query.bias" in state_dict:
|
910 |
+
state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
|
911 |
+
|
912 |
+
# key -> to_k
|
913 |
+
if f"{path}.key.weight" in state_dict:
|
914 |
+
state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
|
915 |
+
if f"{path}.key.bias" in state_dict:
|
916 |
+
state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
|
917 |
+
|
918 |
+
# value -> to_v
|
919 |
+
if f"{path}.value.weight" in state_dict:
|
920 |
+
state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
|
921 |
+
if f"{path}.value.bias" in state_dict:
|
922 |
+
state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
|
923 |
+
|
924 |
+
# proj_attn -> to_out.0
|
925 |
+
if f"{path}.proj_attn.weight" in state_dict:
|
926 |
+
state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
|
927 |
+
if f"{path}.proj_attn.bias" in state_dict:
|
928 |
+
state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
|
929 |
+
|
930 |
+
def _temp_convert_self_to_deprecated_attention_blocks(self):
|
931 |
+
deprecated_attention_block_modules = []
|
932 |
+
|
933 |
+
def recursive_find_attn_block(module):
|
934 |
+
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
935 |
+
deprecated_attention_block_modules.append(module)
|
936 |
+
|
937 |
+
for sub_module in module.children():
|
938 |
+
recursive_find_attn_block(sub_module)
|
939 |
+
|
940 |
+
recursive_find_attn_block(self)
|
941 |
+
|
942 |
+
for module in deprecated_attention_block_modules:
|
943 |
+
module.query = module.to_q
|
944 |
+
module.key = module.to_k
|
945 |
+
module.value = module.to_v
|
946 |
+
module.proj_attn = module.to_out[0]
|
947 |
+
|
948 |
+
# We don't _have_ to delete the old attributes, but it's helpful to ensure
|
949 |
+
# that _all_ the weights are loaded into the new attributes and we're not
|
950 |
+
# making an incorrect assumption that this model should be converted when
|
951 |
+
# it really shouldn't be.
|
952 |
+
del module.to_q
|
953 |
+
del module.to_k
|
954 |
+
del module.to_v
|
955 |
+
del module.to_out
|
956 |
+
|
957 |
+
def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
|
958 |
+
deprecated_attention_block_modules = []
|
959 |
+
|
960 |
+
def recursive_find_attn_block(module):
|
961 |
+
if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
|
962 |
+
deprecated_attention_block_modules.append(module)
|
963 |
+
|
964 |
+
for sub_module in module.children():
|
965 |
+
recursive_find_attn_block(sub_module)
|
966 |
+
|
967 |
+
recursive_find_attn_block(self)
|
968 |
+
|
969 |
+
for module in deprecated_attention_block_modules:
|
970 |
+
module.to_q = module.query
|
971 |
+
module.to_k = module.key
|
972 |
+
module.to_v = module.value
|
973 |
+
module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
|
974 |
+
|
975 |
+
del module.query
|
976 |
+
del module.key
|
977 |
+
del module.value
|
978 |
+
del module.proj_attn
|
diffusers/models/prior_transformer.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Optional, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
from ..utils.configuration_utils import ConfigMixin, register_to_config
|
9 |
+
from ..utils.outputs import BaseOutput
|
10 |
+
from .attention import BasicTransformerBlock
|
11 |
+
from .embeddings import TimestepEmbedding, Timesteps
|
12 |
+
from .modeling_utils import ModelMixin
|
13 |
+
|
14 |
+
|
15 |
+
@dataclass
|
16 |
+
class PriorTransformerOutput(BaseOutput):
|
17 |
+
"""
|
18 |
+
Args:
|
19 |
+
predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
|
20 |
+
The predicted CLIP image embedding conditioned on the CLIP text embedding input.
|
21 |
+
"""
|
22 |
+
|
23 |
+
predicted_image_embedding: torch.FloatTensor
|
24 |
+
|
25 |
+
|
26 |
+
class PriorTransformer(ModelMixin, ConfigMixin):
|
27 |
+
"""
|
28 |
+
The prior transformer from unCLIP is used to predict CLIP image embeddings from CLIP text embeddings. Note that the
|
29 |
+
transformer predicts the image embeddings through a denoising diffusion process.
|
30 |
+
|
31 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
|
32 |
+
implements for all the models (such as downloading or saving, etc.)
|
33 |
+
|
34 |
+
For more details, see the original paper: https://arxiv.org/abs/2204.06125
|
35 |
+
|
36 |
+
Parameters:
|
37 |
+
num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
|
38 |
+
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
|
39 |
+
num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
|
40 |
+
embedding_dim (`int`, *optional*, defaults to 768): The dimension of the CLIP embeddings. Note that CLIP
|
41 |
+
image embeddings and text embeddings are both the same dimension.
|
42 |
+
num_embeddings (`int`, *optional*, defaults to 77): The max number of clip embeddings allowed. I.e. the
|
43 |
+
length of the prompt after it has been tokenized.
|
44 |
+
additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
|
45 |
+
projected hidden_states. The actual length of the used hidden_states is `num_embeddings +
|
46 |
+
additional_embeddings`.
|
47 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
48 |
+
|
49 |
+
"""
|
50 |
+
|
51 |
+
@register_to_config
|
52 |
+
def __init__(
|
53 |
+
self,
|
54 |
+
num_attention_heads: int = 32,
|
55 |
+
attention_head_dim: int = 64,
|
56 |
+
num_layers: int = 20,
|
57 |
+
embedding_dim: int = 768,
|
58 |
+
num_embeddings=77,
|
59 |
+
additional_embeddings=4,
|
60 |
+
dropout: float = 0.0,
|
61 |
+
):
|
62 |
+
super().__init__()
|
63 |
+
self.num_attention_heads = num_attention_heads
|
64 |
+
self.attention_head_dim = attention_head_dim
|
65 |
+
inner_dim = num_attention_heads * attention_head_dim
|
66 |
+
self.additional_embeddings = additional_embeddings
|
67 |
+
|
68 |
+
self.time_proj = Timesteps(inner_dim, True, 0)
|
69 |
+
self.time_embedding = TimestepEmbedding(inner_dim, inner_dim)
|
70 |
+
|
71 |
+
self.proj_in = nn.Linear(embedding_dim, inner_dim)
|
72 |
+
|
73 |
+
self.embedding_proj = nn.Linear(embedding_dim, inner_dim)
|
74 |
+
self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
|
75 |
+
|
76 |
+
self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
|
77 |
+
|
78 |
+
self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
|
79 |
+
|
80 |
+
self.transformer_blocks = nn.ModuleList(
|
81 |
+
[
|
82 |
+
BasicTransformerBlock(
|
83 |
+
inner_dim,
|
84 |
+
num_attention_heads,
|
85 |
+
attention_head_dim,
|
86 |
+
dropout=dropout,
|
87 |
+
activation_fn="gelu",
|
88 |
+
attention_bias=True,
|
89 |
+
)
|
90 |
+
for d in range(num_layers)
|
91 |
+
]
|
92 |
+
)
|
93 |
+
|
94 |
+
self.norm_out = nn.LayerNorm(inner_dim)
|
95 |
+
self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim)
|
96 |
+
|
97 |
+
causal_attention_mask = torch.full(
|
98 |
+
[num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
|
99 |
+
)
|
100 |
+
causal_attention_mask.triu_(1)
|
101 |
+
causal_attention_mask = causal_attention_mask[None, ...]
|
102 |
+
self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
|
103 |
+
|
104 |
+
self.clip_mean = nn.Parameter(torch.zeros(1, embedding_dim))
|
105 |
+
self.clip_std = nn.Parameter(torch.zeros(1, embedding_dim))
|
106 |
+
|
107 |
+
def forward(
|
108 |
+
self,
|
109 |
+
hidden_states,
|
110 |
+
timestep: Union[torch.Tensor, float, int],
|
111 |
+
proj_embedding: torch.FloatTensor,
|
112 |
+
encoder_hidden_states: torch.FloatTensor,
|
113 |
+
attention_mask: Optional[torch.BoolTensor] = None,
|
114 |
+
return_dict: bool = True,
|
115 |
+
):
|
116 |
+
"""
|
117 |
+
Args:
|
118 |
+
hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
|
119 |
+
x_t, the currently predicted image embeddings.
|
120 |
+
timestep (`torch.long`):
|
121 |
+
Current denoising step.
|
122 |
+
proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
|
123 |
+
Projected embedding vector the denoising process is conditioned on.
|
124 |
+
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
|
125 |
+
Hidden states of the text embeddings the denoising process is conditioned on.
|
126 |
+
attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
|
127 |
+
Text mask for the text embeddings.
|
128 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
129 |
+
Whether or not to return a [`models.prior_transformer.PriorTransformerOutput`] instead of a plain
|
130 |
+
tuple.
|
131 |
+
|
132 |
+
Returns:
|
133 |
+
[`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
|
134 |
+
[`~models.prior_transformer.PriorTransformerOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
135 |
+
returning a tuple, the first element is the sample tensor.
|
136 |
+
"""
|
137 |
+
batch_size = hidden_states.shape[0]
|
138 |
+
|
139 |
+
timesteps = timestep
|
140 |
+
if not torch.is_tensor(timesteps):
|
141 |
+
timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
|
142 |
+
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
143 |
+
timesteps = timesteps[None].to(hidden_states.device)
|
144 |
+
|
145 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
146 |
+
timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)
|
147 |
+
|
148 |
+
timesteps_projected = self.time_proj(timesteps)
|
149 |
+
|
150 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
151 |
+
# but time_embedding might be fp16, so we need to cast here.
|
152 |
+
timesteps_projected = timesteps_projected.to(dtype=self.dtype)
|
153 |
+
time_embeddings = self.time_embedding(timesteps_projected)
|
154 |
+
|
155 |
+
proj_embeddings = self.embedding_proj(proj_embedding)
|
156 |
+
encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
|
157 |
+
hidden_states = self.proj_in(hidden_states)
|
158 |
+
prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
|
159 |
+
positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
|
160 |
+
|
161 |
+
hidden_states = torch.cat(
|
162 |
+
[
|
163 |
+
encoder_hidden_states,
|
164 |
+
proj_embeddings[:, None, :],
|
165 |
+
time_embeddings[:, None, :],
|
166 |
+
hidden_states[:, None, :],
|
167 |
+
prd_embedding,
|
168 |
+
],
|
169 |
+
dim=1,
|
170 |
+
)
|
171 |
+
|
172 |
+
hidden_states = hidden_states + positional_embeddings
|
173 |
+
|
174 |
+
if attention_mask is not None:
|
175 |
+
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
176 |
+
attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
|
177 |
+
attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
|
178 |
+
attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
|
179 |
+
|
180 |
+
for block in self.transformer_blocks:
|
181 |
+
hidden_states = block(hidden_states, attention_mask=attention_mask)
|
182 |
+
|
183 |
+
hidden_states = self.norm_out(hidden_states)
|
184 |
+
hidden_states = hidden_states[:, -1]
|
185 |
+
predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
|
186 |
+
|
187 |
+
if not return_dict:
|
188 |
+
return (predicted_image_embedding,)
|
189 |
+
|
190 |
+
return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
|
191 |
+
|
192 |
+
def post_process_latents(self, prior_latents):
|
193 |
+
prior_latents = (prior_latents * self.clip_std) + self.clip_mean
|
194 |
+
return prior_latents
|
diffusers/models/resnet.py
ADDED
@@ -0,0 +1,839 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
from functools import partial
|
17 |
+
from typing import Optional
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
import torch.nn.functional as F
|
22 |
+
|
23 |
+
from .attention import AdaGroupNorm
|
24 |
+
|
25 |
+
|
26 |
+
class Upsample1D(nn.Module):
|
27 |
+
"""
|
28 |
+
An upsampling layer with an optional convolution.
|
29 |
+
|
30 |
+
Parameters:
|
31 |
+
channels: channels in the inputs and outputs.
|
32 |
+
use_conv: a bool determining if a convolution is applied.
|
33 |
+
use_conv_transpose:
|
34 |
+
out_channels:
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
|
38 |
+
super().__init__()
|
39 |
+
self.channels = channels
|
40 |
+
self.out_channels = out_channels or channels
|
41 |
+
self.use_conv = use_conv
|
42 |
+
self.use_conv_transpose = use_conv_transpose
|
43 |
+
self.name = name
|
44 |
+
|
45 |
+
self.conv = None
|
46 |
+
if use_conv_transpose:
|
47 |
+
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
|
48 |
+
elif use_conv:
|
49 |
+
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
assert x.shape[1] == self.channels
|
53 |
+
if self.use_conv_transpose:
|
54 |
+
return self.conv(x)
|
55 |
+
|
56 |
+
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
57 |
+
|
58 |
+
if self.use_conv:
|
59 |
+
x = self.conv(x)
|
60 |
+
|
61 |
+
return x
|
62 |
+
|
63 |
+
|
64 |
+
class Downsample1D(nn.Module):
|
65 |
+
"""
|
66 |
+
A downsampling layer with an optional convolution.
|
67 |
+
|
68 |
+
Parameters:
|
69 |
+
channels: channels in the inputs and outputs.
|
70 |
+
use_conv: a bool determining if a convolution is applied.
|
71 |
+
out_channels:
|
72 |
+
padding:
|
73 |
+
"""
|
74 |
+
|
75 |
+
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
|
76 |
+
super().__init__()
|
77 |
+
self.channels = channels
|
78 |
+
self.out_channels = out_channels or channels
|
79 |
+
self.use_conv = use_conv
|
80 |
+
self.padding = padding
|
81 |
+
stride = 2
|
82 |
+
self.name = name
|
83 |
+
|
84 |
+
if use_conv:
|
85 |
+
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
86 |
+
else:
|
87 |
+
assert self.channels == self.out_channels
|
88 |
+
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
|
89 |
+
|
90 |
+
def forward(self, x):
|
91 |
+
assert x.shape[1] == self.channels
|
92 |
+
return self.conv(x)
|
93 |
+
|
94 |
+
|
95 |
+
class Upsample2D(nn.Module):
|
96 |
+
"""
|
97 |
+
An upsampling layer with an optional convolution.
|
98 |
+
|
99 |
+
Parameters:
|
100 |
+
channels: channels in the inputs and outputs.
|
101 |
+
use_conv: a bool determining if a convolution is applied.
|
102 |
+
use_conv_transpose:
|
103 |
+
out_channels:
|
104 |
+
"""
|
105 |
+
|
106 |
+
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
|
107 |
+
super().__init__()
|
108 |
+
self.channels = channels
|
109 |
+
self.out_channels = out_channels or channels
|
110 |
+
self.use_conv = use_conv
|
111 |
+
self.use_conv_transpose = use_conv_transpose
|
112 |
+
self.name = name
|
113 |
+
|
114 |
+
conv = None
|
115 |
+
if use_conv_transpose:
|
116 |
+
conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
|
117 |
+
elif use_conv:
|
118 |
+
conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
|
119 |
+
|
120 |
+
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
121 |
+
if name == "conv":
|
122 |
+
self.conv = conv
|
123 |
+
else:
|
124 |
+
self.Conv2d_0 = conv
|
125 |
+
|
126 |
+
def forward(self, hidden_states, output_size=None):
|
127 |
+
assert hidden_states.shape[1] == self.channels
|
128 |
+
|
129 |
+
if self.use_conv_transpose:
|
130 |
+
return self.conv(hidden_states)
|
131 |
+
|
132 |
+
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
133 |
+
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
|
134 |
+
# https://github.com/pytorch/pytorch/issues/86679
|
135 |
+
dtype = hidden_states.dtype
|
136 |
+
if dtype == torch.bfloat16:
|
137 |
+
hidden_states = hidden_states.to(torch.float32)
|
138 |
+
|
139 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
140 |
+
if hidden_states.shape[0] >= 64:
|
141 |
+
hidden_states = hidden_states.contiguous()
|
142 |
+
|
143 |
+
# if `output_size` is passed we force the interpolation output
|
144 |
+
# size and do not make use of `scale_factor=2`
|
145 |
+
if output_size is None:
|
146 |
+
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
147 |
+
else:
|
148 |
+
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
|
149 |
+
|
150 |
+
# If the input is bfloat16, we cast back to bfloat16
|
151 |
+
if dtype == torch.bfloat16:
|
152 |
+
hidden_states = hidden_states.to(dtype)
|
153 |
+
|
154 |
+
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
155 |
+
if self.use_conv:
|
156 |
+
if self.name == "conv":
|
157 |
+
hidden_states = self.conv(hidden_states)
|
158 |
+
else:
|
159 |
+
hidden_states = self.Conv2d_0(hidden_states)
|
160 |
+
|
161 |
+
return hidden_states
|
162 |
+
|
163 |
+
|
164 |
+
class Downsample2D(nn.Module):
|
165 |
+
"""
|
166 |
+
A downsampling layer with an optional convolution.
|
167 |
+
|
168 |
+
Parameters:
|
169 |
+
channels: channels in the inputs and outputs.
|
170 |
+
use_conv: a bool determining if a convolution is applied.
|
171 |
+
out_channels:
|
172 |
+
padding:
|
173 |
+
"""
|
174 |
+
|
175 |
+
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
|
176 |
+
super().__init__()
|
177 |
+
self.channels = channels
|
178 |
+
self.out_channels = out_channels or channels
|
179 |
+
self.use_conv = use_conv
|
180 |
+
self.padding = padding
|
181 |
+
stride = 2
|
182 |
+
self.name = name
|
183 |
+
|
184 |
+
if use_conv:
|
185 |
+
conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
186 |
+
else:
|
187 |
+
assert self.channels == self.out_channels
|
188 |
+
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
|
189 |
+
|
190 |
+
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
191 |
+
if name == "conv":
|
192 |
+
self.Conv2d_0 = conv
|
193 |
+
self.conv = conv
|
194 |
+
elif name == "Conv2d_0":
|
195 |
+
self.conv = conv
|
196 |
+
else:
|
197 |
+
self.conv = conv
|
198 |
+
|
199 |
+
def forward(self, hidden_states):
|
200 |
+
assert hidden_states.shape[1] == self.channels
|
201 |
+
if self.use_conv and self.padding == 0:
|
202 |
+
pad = (0, 1, 0, 1)
|
203 |
+
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
|
204 |
+
|
205 |
+
assert hidden_states.shape[1] == self.channels
|
206 |
+
hidden_states = self.conv(hidden_states)
|
207 |
+
|
208 |
+
return hidden_states
|
209 |
+
|
210 |
+
|
211 |
+
class FirUpsample2D(nn.Module):
|
212 |
+
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
|
213 |
+
super().__init__()
|
214 |
+
out_channels = out_channels if out_channels else channels
|
215 |
+
if use_conv:
|
216 |
+
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
|
217 |
+
self.use_conv = use_conv
|
218 |
+
self.fir_kernel = fir_kernel
|
219 |
+
self.out_channels = out_channels
|
220 |
+
|
221 |
+
def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
|
222 |
+
"""Fused `upsample_2d()` followed by `Conv2d()`.
|
223 |
+
|
224 |
+
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
225 |
+
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
|
226 |
+
arbitrary order.
|
227 |
+
|
228 |
+
Args:
|
229 |
+
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
230 |
+
weight: Weight tensor of the shape `[filterH, filterW, inChannels,
|
231 |
+
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
|
232 |
+
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
|
233 |
+
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
|
234 |
+
factor: Integer upsampling factor (default: 2).
|
235 |
+
gain: Scaling factor for signal magnitude (default: 1.0).
|
236 |
+
|
237 |
+
Returns:
|
238 |
+
output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
|
239 |
+
datatype as `hidden_states`.
|
240 |
+
"""
|
241 |
+
|
242 |
+
assert isinstance(factor, int) and factor >= 1
|
243 |
+
|
244 |
+
# Setup filter kernel.
|
245 |
+
if kernel is None:
|
246 |
+
kernel = [1] * factor
|
247 |
+
|
248 |
+
# setup kernel
|
249 |
+
kernel = torch.tensor(kernel, dtype=torch.float32)
|
250 |
+
if kernel.ndim == 1:
|
251 |
+
kernel = torch.outer(kernel, kernel)
|
252 |
+
kernel /= torch.sum(kernel)
|
253 |
+
|
254 |
+
kernel = kernel * (gain * (factor**2))
|
255 |
+
|
256 |
+
if self.use_conv:
|
257 |
+
convH = weight.shape[2]
|
258 |
+
convW = weight.shape[3]
|
259 |
+
inC = weight.shape[1]
|
260 |
+
|
261 |
+
pad_value = (kernel.shape[0] - factor) - (convW - 1)
|
262 |
+
|
263 |
+
stride = (factor, factor)
|
264 |
+
# Determine data dimensions.
|
265 |
+
output_shape = (
|
266 |
+
(hidden_states.shape[2] - 1) * factor + convH,
|
267 |
+
(hidden_states.shape[3] - 1) * factor + convW,
|
268 |
+
)
|
269 |
+
output_padding = (
|
270 |
+
output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
|
271 |
+
output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
|
272 |
+
)
|
273 |
+
assert output_padding[0] >= 0 and output_padding[1] >= 0
|
274 |
+
num_groups = hidden_states.shape[1] // inC
|
275 |
+
|
276 |
+
# Transpose weights.
|
277 |
+
weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
|
278 |
+
weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
|
279 |
+
weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
|
280 |
+
|
281 |
+
inverse_conv = F.conv_transpose2d(
|
282 |
+
hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
|
283 |
+
)
|
284 |
+
|
285 |
+
output = upfirdn2d_native(
|
286 |
+
inverse_conv,
|
287 |
+
torch.tensor(kernel, device=inverse_conv.device),
|
288 |
+
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
|
289 |
+
)
|
290 |
+
else:
|
291 |
+
pad_value = kernel.shape[0] - factor
|
292 |
+
output = upfirdn2d_native(
|
293 |
+
hidden_states,
|
294 |
+
torch.tensor(kernel, device=hidden_states.device),
|
295 |
+
up=factor,
|
296 |
+
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
|
297 |
+
)
|
298 |
+
|
299 |
+
return output
|
300 |
+
|
301 |
+
def forward(self, hidden_states):
|
302 |
+
if self.use_conv:
|
303 |
+
height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
|
304 |
+
height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
305 |
+
else:
|
306 |
+
height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
|
307 |
+
|
308 |
+
return height
|
309 |
+
|
310 |
+
|
311 |
+
class FirDownsample2D(nn.Module):
|
312 |
+
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
|
313 |
+
super().__init__()
|
314 |
+
out_channels = out_channels if out_channels else channels
|
315 |
+
if use_conv:
|
316 |
+
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
|
317 |
+
self.fir_kernel = fir_kernel
|
318 |
+
self.use_conv = use_conv
|
319 |
+
self.out_channels = out_channels
|
320 |
+
|
321 |
+
def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
|
322 |
+
"""Fused `Conv2d()` followed by `downsample_2d()`.
|
323 |
+
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
324 |
+
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
|
325 |
+
arbitrary order.
|
326 |
+
|
327 |
+
Args:
|
328 |
+
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
329 |
+
weight:
|
330 |
+
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
|
331 |
+
performed by `inChannels = x.shape[0] // numGroups`.
|
332 |
+
kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
|
333 |
+
factor`, which corresponds to average pooling.
|
334 |
+
factor: Integer downsampling factor (default: 2).
|
335 |
+
gain: Scaling factor for signal magnitude (default: 1.0).
|
336 |
+
|
337 |
+
Returns:
|
338 |
+
output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
|
339 |
+
same datatype as `x`.
|
340 |
+
"""
|
341 |
+
|
342 |
+
assert isinstance(factor, int) and factor >= 1
|
343 |
+
if kernel is None:
|
344 |
+
kernel = [1] * factor
|
345 |
+
|
346 |
+
# setup kernel
|
347 |
+
kernel = torch.tensor(kernel, dtype=torch.float32)
|
348 |
+
if kernel.ndim == 1:
|
349 |
+
kernel = torch.outer(kernel, kernel)
|
350 |
+
kernel /= torch.sum(kernel)
|
351 |
+
|
352 |
+
kernel = kernel * gain
|
353 |
+
|
354 |
+
if self.use_conv:
|
355 |
+
_, _, convH, convW = weight.shape
|
356 |
+
pad_value = (kernel.shape[0] - factor) + (convW - 1)
|
357 |
+
stride_value = [factor, factor]
|
358 |
+
upfirdn_input = upfirdn2d_native(
|
359 |
+
hidden_states,
|
360 |
+
torch.tensor(kernel, device=hidden_states.device),
|
361 |
+
pad=((pad_value + 1) // 2, pad_value // 2),
|
362 |
+
)
|
363 |
+
output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
|
364 |
+
else:
|
365 |
+
pad_value = kernel.shape[0] - factor
|
366 |
+
output = upfirdn2d_native(
|
367 |
+
hidden_states,
|
368 |
+
torch.tensor(kernel, device=hidden_states.device),
|
369 |
+
down=factor,
|
370 |
+
pad=((pad_value + 1) // 2, pad_value // 2),
|
371 |
+
)
|
372 |
+
|
373 |
+
return output
|
374 |
+
|
375 |
+
def forward(self, hidden_states):
|
376 |
+
if self.use_conv:
|
377 |
+
downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
|
378 |
+
hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
379 |
+
else:
|
380 |
+
hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
|
381 |
+
|
382 |
+
return hidden_states
|
383 |
+
|
384 |
+
|
385 |
+
# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/DirUpsample2D instead
|
386 |
+
class KDownsample2D(nn.Module):
|
387 |
+
def __init__(self, pad_mode="reflect"):
|
388 |
+
super().__init__()
|
389 |
+
self.pad_mode = pad_mode
|
390 |
+
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
|
391 |
+
self.pad = kernel_1d.shape[1] // 2 - 1
|
392 |
+
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
|
393 |
+
|
394 |
+
def forward(self, x):
|
395 |
+
x = F.pad(x, (self.pad,) * 4, self.pad_mode)
|
396 |
+
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
|
397 |
+
indices = torch.arange(x.shape[1], device=x.device)
|
398 |
+
weight[indices, indices] = self.kernel.to(weight)
|
399 |
+
return F.conv2d(x, weight, stride=2)
|
400 |
+
|
401 |
+
|
402 |
+
class KUpsample2D(nn.Module):
|
403 |
+
def __init__(self, pad_mode="reflect"):
|
404 |
+
super().__init__()
|
405 |
+
self.pad_mode = pad_mode
|
406 |
+
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
|
407 |
+
self.pad = kernel_1d.shape[1] // 2 - 1
|
408 |
+
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
|
409 |
+
|
410 |
+
def forward(self, x):
|
411 |
+
x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
|
412 |
+
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
|
413 |
+
indices = torch.arange(x.shape[1], device=x.device)
|
414 |
+
weight[indices, indices] = self.kernel.to(weight)
|
415 |
+
return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)
|
416 |
+
|
417 |
+
|
418 |
+
class ResnetBlock2D(nn.Module):
|
419 |
+
r"""
|
420 |
+
A Resnet block.
|
421 |
+
|
422 |
+
Parameters:
|
423 |
+
in_channels (`int`): The number of channels in the input.
|
424 |
+
out_channels (`int`, *optional*, default to be `None`):
|
425 |
+
The number of output channels for the first conv2d layer. If None, same as `in_channels`.
|
426 |
+
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
|
427 |
+
temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
|
428 |
+
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
|
429 |
+
groups_out (`int`, *optional*, default to None):
|
430 |
+
The number of groups to use for the second normalization layer. if set to None, same as `groups`.
|
431 |
+
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
|
432 |
+
non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
|
433 |
+
time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
|
434 |
+
By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
|
435 |
+
"ada_group" for a stronger conditioning with scale and shift.
|
436 |
+
kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
|
437 |
+
[`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
|
438 |
+
output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
|
439 |
+
use_in_shortcut (`bool`, *optional*, default to `True`):
|
440 |
+
If `True`, add a 1x1 nn.conv2d layer for skip-connection.
|
441 |
+
up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
|
442 |
+
down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
|
443 |
+
conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
|
444 |
+
`conv_shortcut` output.
|
445 |
+
conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
|
446 |
+
If None, same as `out_channels`.
|
447 |
+
"""
|
448 |
+
|
449 |
+
def __init__(
|
450 |
+
self,
|
451 |
+
*,
|
452 |
+
in_channels,
|
453 |
+
out_channels=None,
|
454 |
+
conv_shortcut=False,
|
455 |
+
dropout=0.0,
|
456 |
+
temb_channels=512,
|
457 |
+
groups=32,
|
458 |
+
groups_out=None,
|
459 |
+
pre_norm=True,
|
460 |
+
eps=1e-6,
|
461 |
+
non_linearity="swish",
|
462 |
+
time_embedding_norm="default", # default, scale_shift, ada_group
|
463 |
+
kernel=None,
|
464 |
+
output_scale_factor=1.0,
|
465 |
+
use_in_shortcut=None,
|
466 |
+
up=False,
|
467 |
+
down=False,
|
468 |
+
conv_shortcut_bias: bool = True,
|
469 |
+
conv_2d_out_channels: Optional[int] = None,
|
470 |
+
):
|
471 |
+
super().__init__()
|
472 |
+
self.pre_norm = pre_norm
|
473 |
+
self.pre_norm = True
|
474 |
+
self.in_channels = in_channels
|
475 |
+
out_channels = in_channels if out_channels is None else out_channels
|
476 |
+
self.out_channels = out_channels
|
477 |
+
self.use_conv_shortcut = conv_shortcut
|
478 |
+
self.up = up
|
479 |
+
self.down = down
|
480 |
+
self.output_scale_factor = output_scale_factor
|
481 |
+
self.time_embedding_norm = time_embedding_norm
|
482 |
+
|
483 |
+
if groups_out is None:
|
484 |
+
groups_out = groups
|
485 |
+
|
486 |
+
if self.time_embedding_norm == "ada_group":
|
487 |
+
self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
|
488 |
+
else:
|
489 |
+
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
490 |
+
|
491 |
+
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
492 |
+
|
493 |
+
if temb_channels is not None:
|
494 |
+
if self.time_embedding_norm == "default":
|
495 |
+
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
|
496 |
+
elif self.time_embedding_norm == "scale_shift":
|
497 |
+
self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
|
498 |
+
elif self.time_embedding_norm == "ada_group":
|
499 |
+
self.time_emb_proj = None
|
500 |
+
else:
|
501 |
+
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
|
502 |
+
else:
|
503 |
+
self.time_emb_proj = None
|
504 |
+
|
505 |
+
if self.time_embedding_norm == "ada_group":
|
506 |
+
self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
|
507 |
+
else:
|
508 |
+
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
509 |
+
|
510 |
+
self.dropout = torch.nn.Dropout(dropout)
|
511 |
+
conv_2d_out_channels = conv_2d_out_channels or out_channels
|
512 |
+
self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
|
513 |
+
|
514 |
+
if non_linearity == "swish":
|
515 |
+
self.nonlinearity = lambda x: F.silu(x)
|
516 |
+
elif non_linearity == "mish":
|
517 |
+
self.nonlinearity = nn.Mish()
|
518 |
+
elif non_linearity == "silu":
|
519 |
+
self.nonlinearity = nn.SiLU()
|
520 |
+
elif non_linearity == "gelu":
|
521 |
+
self.nonlinearity = nn.GELU()
|
522 |
+
|
523 |
+
self.upsample = self.downsample = None
|
524 |
+
if self.up:
|
525 |
+
if kernel == "fir":
|
526 |
+
fir_kernel = (1, 3, 3, 1)
|
527 |
+
self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
|
528 |
+
elif kernel == "sde_vp":
|
529 |
+
self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
|
530 |
+
else:
|
531 |
+
self.upsample = Upsample2D(in_channels, use_conv=False)
|
532 |
+
elif self.down:
|
533 |
+
if kernel == "fir":
|
534 |
+
fir_kernel = (1, 3, 3, 1)
|
535 |
+
self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
|
536 |
+
elif kernel == "sde_vp":
|
537 |
+
self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
|
538 |
+
else:
|
539 |
+
self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
|
540 |
+
|
541 |
+
self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
|
542 |
+
|
543 |
+
self.conv_shortcut = None
|
544 |
+
if self.use_in_shortcut:
|
545 |
+
self.conv_shortcut = torch.nn.Conv2d(
|
546 |
+
in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
|
547 |
+
)
|
548 |
+
|
549 |
+
def forward(self, input_tensor, temb):
|
550 |
+
hidden_states = input_tensor
|
551 |
+
|
552 |
+
if self.time_embedding_norm == "ada_group":
|
553 |
+
hidden_states = self.norm1(hidden_states, temb)
|
554 |
+
else:
|
555 |
+
hidden_states = self.norm1(hidden_states)
|
556 |
+
|
557 |
+
hidden_states = self.nonlinearity(hidden_states)
|
558 |
+
|
559 |
+
if self.upsample is not None:
|
560 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
561 |
+
if hidden_states.shape[0] >= 64:
|
562 |
+
input_tensor = input_tensor.contiguous()
|
563 |
+
hidden_states = hidden_states.contiguous()
|
564 |
+
input_tensor = self.upsample(input_tensor)
|
565 |
+
hidden_states = self.upsample(hidden_states)
|
566 |
+
elif self.downsample is not None:
|
567 |
+
input_tensor = self.downsample(input_tensor)
|
568 |
+
hidden_states = self.downsample(hidden_states)
|
569 |
+
|
570 |
+
hidden_states = self.conv1(hidden_states)
|
571 |
+
|
572 |
+
if self.time_emb_proj is not None:
|
573 |
+
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
|
574 |
+
|
575 |
+
if temb is not None and self.time_embedding_norm == "default":
|
576 |
+
hidden_states = hidden_states + temb
|
577 |
+
|
578 |
+
if self.time_embedding_norm == "ada_group":
|
579 |
+
hidden_states = self.norm2(hidden_states, temb)
|
580 |
+
else:
|
581 |
+
hidden_states = self.norm2(hidden_states)
|
582 |
+
|
583 |
+
if temb is not None and self.time_embedding_norm == "scale_shift":
|
584 |
+
scale, shift = torch.chunk(temb, 2, dim=1)
|
585 |
+
hidden_states = hidden_states * (1 + scale) + shift
|
586 |
+
|
587 |
+
hidden_states = self.nonlinearity(hidden_states)
|
588 |
+
|
589 |
+
hidden_states = self.dropout(hidden_states)
|
590 |
+
hidden_states = self.conv2(hidden_states)
|
591 |
+
|
592 |
+
if self.conv_shortcut is not None:
|
593 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
594 |
+
|
595 |
+
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
596 |
+
|
597 |
+
return output_tensor
|
598 |
+
|
599 |
+
|
600 |
+
class Mish(torch.nn.Module):
|
601 |
+
def forward(self, hidden_states):
|
602 |
+
return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
|
603 |
+
|
604 |
+
|
605 |
+
# unet_rl.py
|
606 |
+
def rearrange_dims(tensor):
|
607 |
+
if len(tensor.shape) == 2:
|
608 |
+
return tensor[:, :, None]
|
609 |
+
if len(tensor.shape) == 3:
|
610 |
+
return tensor[:, :, None, :]
|
611 |
+
elif len(tensor.shape) == 4:
|
612 |
+
return tensor[:, :, 0, :]
|
613 |
+
else:
|
614 |
+
raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
|
615 |
+
|
616 |
+
|
617 |
+
class Conv1dBlock(nn.Module):
|
618 |
+
"""
|
619 |
+
Conv1d --> GroupNorm --> Mish
|
620 |
+
"""
|
621 |
+
|
622 |
+
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
|
623 |
+
super().__init__()
|
624 |
+
|
625 |
+
self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
|
626 |
+
self.group_norm = nn.GroupNorm(n_groups, out_channels)
|
627 |
+
self.mish = nn.Mish()
|
628 |
+
|
629 |
+
def forward(self, x):
|
630 |
+
x = self.conv1d(x)
|
631 |
+
x = rearrange_dims(x)
|
632 |
+
x = self.group_norm(x)
|
633 |
+
x = rearrange_dims(x)
|
634 |
+
x = self.mish(x)
|
635 |
+
return x
|
636 |
+
|
637 |
+
|
638 |
+
# unet_rl.py
|
639 |
+
class ResidualTemporalBlock1D(nn.Module):
|
640 |
+
def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
|
641 |
+
super().__init__()
|
642 |
+
self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
|
643 |
+
self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
|
644 |
+
|
645 |
+
self.time_emb_act = nn.Mish()
|
646 |
+
self.time_emb = nn.Linear(embed_dim, out_channels)
|
647 |
+
|
648 |
+
self.residual_conv = (
|
649 |
+
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
|
650 |
+
)
|
651 |
+
|
652 |
+
def forward(self, x, t):
|
653 |
+
"""
|
654 |
+
Args:
|
655 |
+
x : [ batch_size x inp_channels x horizon ]
|
656 |
+
t : [ batch_size x embed_dim ]
|
657 |
+
|
658 |
+
returns:
|
659 |
+
out : [ batch_size x out_channels x horizon ]
|
660 |
+
"""
|
661 |
+
t = self.time_emb_act(t)
|
662 |
+
t = self.time_emb(t)
|
663 |
+
out = self.conv_in(x) + rearrange_dims(t)
|
664 |
+
out = self.conv_out(out)
|
665 |
+
return out + self.residual_conv(x)
|
666 |
+
|
667 |
+
|
668 |
+
def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
|
669 |
+
r"""Upsample2D a batch of 2D images with the given filter.
|
670 |
+
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
|
671 |
+
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
|
672 |
+
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
|
673 |
+
a: multiple of the upsampling factor.
|
674 |
+
|
675 |
+
Args:
|
676 |
+
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
677 |
+
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
|
678 |
+
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
|
679 |
+
factor: Integer upsampling factor (default: 2).
|
680 |
+
gain: Scaling factor for signal magnitude (default: 1.0).
|
681 |
+
|
682 |
+
Returns:
|
683 |
+
output: Tensor of the shape `[N, C, H * factor, W * factor]`
|
684 |
+
"""
|
685 |
+
assert isinstance(factor, int) and factor >= 1
|
686 |
+
if kernel is None:
|
687 |
+
kernel = [1] * factor
|
688 |
+
|
689 |
+
kernel = torch.tensor(kernel, dtype=torch.float32)
|
690 |
+
if kernel.ndim == 1:
|
691 |
+
kernel = torch.outer(kernel, kernel)
|
692 |
+
kernel /= torch.sum(kernel)
|
693 |
+
|
694 |
+
kernel = kernel * (gain * (factor**2))
|
695 |
+
pad_value = kernel.shape[0] - factor
|
696 |
+
output = upfirdn2d_native(
|
697 |
+
hidden_states,
|
698 |
+
kernel.to(device=hidden_states.device),
|
699 |
+
up=factor,
|
700 |
+
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
|
701 |
+
)
|
702 |
+
return output
|
703 |
+
|
704 |
+
|
705 |
+
def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
|
706 |
+
r"""Downsample2D a batch of 2D images with the given filter.
|
707 |
+
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
|
708 |
+
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
|
709 |
+
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
|
710 |
+
shape is a multiple of the downsampling factor.
|
711 |
+
|
712 |
+
Args:
|
713 |
+
hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
714 |
+
kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
|
715 |
+
(separable). The default is `[1] * factor`, which corresponds to average pooling.
|
716 |
+
factor: Integer downsampling factor (default: 2).
|
717 |
+
gain: Scaling factor for signal magnitude (default: 1.0).
|
718 |
+
|
719 |
+
Returns:
|
720 |
+
output: Tensor of the shape `[N, C, H // factor, W // factor]`
|
721 |
+
"""
|
722 |
+
|
723 |
+
assert isinstance(factor, int) and factor >= 1
|
724 |
+
if kernel is None:
|
725 |
+
kernel = [1] * factor
|
726 |
+
|
727 |
+
kernel = torch.tensor(kernel, dtype=torch.float32)
|
728 |
+
if kernel.ndim == 1:
|
729 |
+
kernel = torch.outer(kernel, kernel)
|
730 |
+
kernel /= torch.sum(kernel)
|
731 |
+
|
732 |
+
kernel = kernel * gain
|
733 |
+
pad_value = kernel.shape[0] - factor
|
734 |
+
output = upfirdn2d_native(
|
735 |
+
hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
|
736 |
+
)
|
737 |
+
return output
|
738 |
+
|
739 |
+
|
740 |
+
def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
|
741 |
+
up_x = up_y = up
|
742 |
+
down_x = down_y = down
|
743 |
+
pad_x0 = pad_y0 = pad[0]
|
744 |
+
pad_x1 = pad_y1 = pad[1]
|
745 |
+
|
746 |
+
_, channel, in_h, in_w = tensor.shape
|
747 |
+
tensor = tensor.reshape(-1, in_h, in_w, 1)
|
748 |
+
|
749 |
+
_, in_h, in_w, minor = tensor.shape
|
750 |
+
kernel_h, kernel_w = kernel.shape
|
751 |
+
|
752 |
+
out = tensor.view(-1, in_h, 1, in_w, 1, minor)
|
753 |
+
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
|
754 |
+
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
|
755 |
+
|
756 |
+
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
|
757 |
+
out = out.to(tensor.device) # Move back to mps if necessary
|
758 |
+
out = out[
|
759 |
+
:,
|
760 |
+
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
|
761 |
+
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
|
762 |
+
:,
|
763 |
+
]
|
764 |
+
|
765 |
+
out = out.permute(0, 3, 1, 2)
|
766 |
+
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
|
767 |
+
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
768 |
+
out = F.conv2d(out, w)
|
769 |
+
out = out.reshape(
|
770 |
+
-1,
|
771 |
+
minor,
|
772 |
+
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
773 |
+
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
|
774 |
+
)
|
775 |
+
out = out.permute(0, 2, 3, 1)
|
776 |
+
out = out[:, ::down_y, ::down_x, :]
|
777 |
+
|
778 |
+
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
|
779 |
+
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
|
780 |
+
|
781 |
+
return out.view(-1, channel, out_h, out_w)
|
782 |
+
|
783 |
+
|
784 |
+
class TemporalConvLayer(nn.Module):
|
785 |
+
"""
|
786 |
+
Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from:
|
787 |
+
https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
|
788 |
+
"""
|
789 |
+
|
790 |
+
def __init__(self, in_dim, out_dim=None, dropout=0.0):
|
791 |
+
super().__init__()
|
792 |
+
out_dim = out_dim or in_dim
|
793 |
+
self.in_dim = in_dim
|
794 |
+
self.out_dim = out_dim
|
795 |
+
|
796 |
+
# conv layers
|
797 |
+
self.conv1 = nn.Sequential(
|
798 |
+
nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
|
799 |
+
)
|
800 |
+
self.conv2 = nn.Sequential(
|
801 |
+
nn.GroupNorm(32, out_dim),
|
802 |
+
nn.SiLU(),
|
803 |
+
nn.Dropout(dropout),
|
804 |
+
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
|
805 |
+
)
|
806 |
+
self.conv3 = nn.Sequential(
|
807 |
+
nn.GroupNorm(32, out_dim),
|
808 |
+
nn.SiLU(),
|
809 |
+
nn.Dropout(dropout),
|
810 |
+
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
|
811 |
+
)
|
812 |
+
self.conv4 = nn.Sequential(
|
813 |
+
nn.GroupNorm(32, out_dim),
|
814 |
+
nn.SiLU(),
|
815 |
+
nn.Dropout(dropout),
|
816 |
+
nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
|
817 |
+
)
|
818 |
+
|
819 |
+
# zero out the last layer params,so the conv block is identity
|
820 |
+
nn.init.zeros_(self.conv4[-1].weight)
|
821 |
+
nn.init.zeros_(self.conv4[-1].bias)
|
822 |
+
|
823 |
+
def forward(self, hidden_states, num_frames=1):
|
824 |
+
hidden_states = (
|
825 |
+
hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
|
826 |
+
)
|
827 |
+
|
828 |
+
identity = hidden_states
|
829 |
+
hidden_states = self.conv1(hidden_states)
|
830 |
+
hidden_states = self.conv2(hidden_states)
|
831 |
+
hidden_states = self.conv3(hidden_states)
|
832 |
+
hidden_states = self.conv4(hidden_states)
|
833 |
+
|
834 |
+
hidden_states = identity + hidden_states
|
835 |
+
|
836 |
+
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
|
837 |
+
(hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
|
838 |
+
)
|
839 |
+
return hidden_states
|
diffusers/models/transformer_2d.py
ADDED
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from dataclasses import dataclass
|
16 |
+
from typing import Any, Dict, Optional
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn.functional as F
|
20 |
+
from torch import nn
|
21 |
+
|
22 |
+
from ..utils.configuration_utils import ConfigMixin, register_to_config
|
23 |
+
from ..utils.outputs import BaseOutput
|
24 |
+
from ..utils.deprecation_utils import deprecate
|
25 |
+
from ..models.embeddings import ImagePositionalEmbeddings
|
26 |
+
from .attention import BasicTransformerBlock
|
27 |
+
from .embeddings import PatchEmbed
|
28 |
+
from .modeling_utils import ModelMixin
|
29 |
+
|
30 |
+
|
31 |
+
@dataclass
|
32 |
+
class Transformer2DModelOutput(BaseOutput):
|
33 |
+
"""
|
34 |
+
Args:
|
35 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or
|
36 |
+
`(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
|
37 |
+
Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
|
38 |
+
for the unnoised latent pixels.
|
39 |
+
"""
|
40 |
+
|
41 |
+
sample: torch.FloatTensor
|
42 |
+
|
43 |
+
|
44 |
+
class Transformer2DModel(ModelMixin, ConfigMixin):
|
45 |
+
"""
|
46 |
+
Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
|
47 |
+
embeddings) inputs.
|
48 |
+
|
49 |
+
When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
|
50 |
+
transformer action. Finally, reshape to image.
|
51 |
+
|
52 |
+
When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
|
53 |
+
embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
|
54 |
+
classes of unnoised image.
|
55 |
+
|
56 |
+
Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
|
57 |
+
image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
|
58 |
+
|
59 |
+
Parameters:
|
60 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
61 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
62 |
+
in_channels (`int`, *optional*):
|
63 |
+
Pass if the input is continuous. The number of channels in the input and output.
|
64 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
65 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
66 |
+
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
|
67 |
+
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
|
68 |
+
Note that this is fixed at training time as it is used for learning a number of position embeddings.
|
69 |
+
See `ImagePositionalEmbeddings`.
|
70 |
+
num_vector_embeds (`int`, *optional*):
|
71 |
+
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
|
72 |
+
Includes the class for the masked latent pixel.
|
73 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
74 |
+
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
|
75 |
+
The number of diffusion steps used during training. Note that this is fixed at training time as it is
|
76 |
+
used to learn a number of embeddings that are added to the hidden states. During inference, you can
|
77 |
+
denoise for up to but not more than steps than `num_embeds_ada_norm`.
|
78 |
+
attention_bias (`bool`, *optional*):
|
79 |
+
Configure if the TransformerBlocks' attention should contain a bias parameter.
|
80 |
+
"""
|
81 |
+
|
82 |
+
@register_to_config
|
83 |
+
def __init__(
|
84 |
+
self,
|
85 |
+
num_attention_heads: int = 16,
|
86 |
+
attention_head_dim: int = 88,
|
87 |
+
in_channels: Optional[int] = None,
|
88 |
+
out_channels: Optional[int] = None,
|
89 |
+
num_layers: int = 1,
|
90 |
+
dropout: float = 0.0,
|
91 |
+
norm_num_groups: int = 32,
|
92 |
+
cross_attention_dim: Optional[int] = None,
|
93 |
+
attention_bias: bool = False,
|
94 |
+
sample_size: Optional[int] = None,
|
95 |
+
num_vector_embeds: Optional[int] = None,
|
96 |
+
patch_size: Optional[int] = None,
|
97 |
+
activation_fn: str = "geglu",
|
98 |
+
num_embeds_ada_norm: Optional[int] = None,
|
99 |
+
use_linear_projection: bool = False,
|
100 |
+
only_cross_attention: bool = False,
|
101 |
+
upcast_attention: bool = False,
|
102 |
+
norm_type: str = "layer_norm",
|
103 |
+
norm_elementwise_affine: bool = True,
|
104 |
+
):
|
105 |
+
super().__init__()
|
106 |
+
self.use_linear_projection = use_linear_projection
|
107 |
+
self.num_attention_heads = num_attention_heads
|
108 |
+
self.attention_head_dim = attention_head_dim
|
109 |
+
inner_dim = num_attention_heads * attention_head_dim
|
110 |
+
|
111 |
+
# 1. Transformer2DModel can process both standard continuous images of
|
112 |
+
# shape `(batch_size, num_channels, width, height)` as well as
|
113 |
+
# quantized image embeddings of shape `(batch_size, num_image_vectors)`
|
114 |
+
# Define whether input is continuous or discrete depending on configuration
|
115 |
+
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
|
116 |
+
self.is_input_vectorized = num_vector_embeds is not None
|
117 |
+
self.is_input_patches = in_channels is not None and patch_size is not None
|
118 |
+
|
119 |
+
if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
|
120 |
+
deprecation_message = (
|
121 |
+
f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
|
122 |
+
" incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
|
123 |
+
" Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
|
124 |
+
" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
|
125 |
+
" would be very nice if you could open a Pull request for the `transformer/config.json` file"
|
126 |
+
)
|
127 |
+
deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
|
128 |
+
norm_type = "ada_norm"
|
129 |
+
|
130 |
+
if self.is_input_continuous and self.is_input_vectorized:
|
131 |
+
raise ValueError(
|
132 |
+
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
|
133 |
+
" sure that either `in_channels` or `num_vector_embeds` is None."
|
134 |
+
)
|
135 |
+
elif self.is_input_vectorized and self.is_input_patches:
|
136 |
+
raise ValueError(
|
137 |
+
f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
|
138 |
+
" sure that either `num_vector_embeds` or `num_patches` is None."
|
139 |
+
)
|
140 |
+
elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
|
141 |
+
raise ValueError(
|
142 |
+
f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
|
143 |
+
f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
|
144 |
+
)
|
145 |
+
|
146 |
+
# 2. Define input layers
|
147 |
+
if self.is_input_continuous:
|
148 |
+
self.in_channels = in_channels
|
149 |
+
|
150 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
151 |
+
if use_linear_projection:
|
152 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
153 |
+
else:
|
154 |
+
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
155 |
+
elif self.is_input_vectorized:
|
156 |
+
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
|
157 |
+
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
|
158 |
+
|
159 |
+
self.height = sample_size
|
160 |
+
self.width = sample_size
|
161 |
+
self.num_vector_embeds = num_vector_embeds
|
162 |
+
self.num_latent_pixels = self.height * self.width
|
163 |
+
|
164 |
+
self.latent_image_embedding = ImagePositionalEmbeddings(
|
165 |
+
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
|
166 |
+
)
|
167 |
+
elif self.is_input_patches:
|
168 |
+
assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
|
169 |
+
|
170 |
+
self.height = sample_size
|
171 |
+
self.width = sample_size
|
172 |
+
|
173 |
+
self.patch_size = patch_size
|
174 |
+
self.pos_embed = PatchEmbed(
|
175 |
+
height=sample_size,
|
176 |
+
width=sample_size,
|
177 |
+
patch_size=patch_size,
|
178 |
+
in_channels=in_channels,
|
179 |
+
embed_dim=inner_dim,
|
180 |
+
)
|
181 |
+
|
182 |
+
# 3. Define transformers blocks
|
183 |
+
self.transformer_blocks = nn.ModuleList(
|
184 |
+
[
|
185 |
+
BasicTransformerBlock(
|
186 |
+
inner_dim,
|
187 |
+
num_attention_heads,
|
188 |
+
attention_head_dim,
|
189 |
+
dropout=dropout,
|
190 |
+
cross_attention_dim=cross_attention_dim,
|
191 |
+
activation_fn=activation_fn,
|
192 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
193 |
+
attention_bias=attention_bias,
|
194 |
+
only_cross_attention=only_cross_attention,
|
195 |
+
upcast_attention=upcast_attention,
|
196 |
+
norm_type=norm_type,
|
197 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
198 |
+
)
|
199 |
+
for d in range(num_layers)
|
200 |
+
]
|
201 |
+
)
|
202 |
+
|
203 |
+
# 4. Define output layers
|
204 |
+
self.out_channels = in_channels if out_channels is None else out_channels
|
205 |
+
if self.is_input_continuous:
|
206 |
+
# TODO: should use out_channels for continuous projections
|
207 |
+
if use_linear_projection:
|
208 |
+
self.proj_out = nn.Linear(inner_dim, in_channels)
|
209 |
+
else:
|
210 |
+
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
211 |
+
elif self.is_input_vectorized:
|
212 |
+
self.norm_out = nn.LayerNorm(inner_dim)
|
213 |
+
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
|
214 |
+
elif self.is_input_patches:
|
215 |
+
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
216 |
+
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
|
217 |
+
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
|
218 |
+
|
219 |
+
def forward(
|
220 |
+
self,
|
221 |
+
hidden_states: torch.Tensor,
|
222 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
223 |
+
timestep: Optional[torch.LongTensor] = None,
|
224 |
+
class_labels: Optional[torch.LongTensor] = None,
|
225 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
226 |
+
attention_mask: Optional[torch.Tensor] = None,
|
227 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
228 |
+
return_dict: bool = True,
|
229 |
+
):
|
230 |
+
"""
|
231 |
+
Args:
|
232 |
+
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
|
233 |
+
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
|
234 |
+
hidden_states
|
235 |
+
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
236 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
237 |
+
self-attention.
|
238 |
+
timestep ( `torch.LongTensor`, *optional*):
|
239 |
+
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
|
240 |
+
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
241 |
+
Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class
|
242 |
+
labels conditioning.
|
243 |
+
attention_mask ( `torch.Tensor` of shape (batch size, num latent pixels), *optional* ).
|
244 |
+
Bias to add to attention scores.
|
245 |
+
encoder_attention_mask ( `torch.Tensor` of shape (batch size, num encoder tokens), *optional* ).
|
246 |
+
Bias to add to cross-attention scores.
|
247 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
248 |
+
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
249 |
+
|
250 |
+
Returns:
|
251 |
+
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
|
252 |
+
[`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`.
|
253 |
+
When returning a tuple, the first element is the sample tensor.
|
254 |
+
"""
|
255 |
+
# 1. Input
|
256 |
+
if self.is_input_continuous:
|
257 |
+
batch, _, height, width = hidden_states.shape
|
258 |
+
residual = hidden_states
|
259 |
+
|
260 |
+
hidden_states = self.norm(hidden_states)
|
261 |
+
if not self.use_linear_projection:
|
262 |
+
hidden_states = self.proj_in(hidden_states)
|
263 |
+
inner_dim = hidden_states.shape[1]
|
264 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
265 |
+
else:
|
266 |
+
inner_dim = hidden_states.shape[1]
|
267 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
268 |
+
hidden_states = self.proj_in(hidden_states)
|
269 |
+
|
270 |
+
elif self.is_input_vectorized:
|
271 |
+
hidden_states = self.latent_image_embedding(hidden_states)
|
272 |
+
|
273 |
+
elif self.is_input_patches:
|
274 |
+
hidden_states = self.pos_embed(hidden_states)
|
275 |
+
|
276 |
+
# 2. Blocks
|
277 |
+
for block in self.transformer_blocks:
|
278 |
+
hidden_states = block(
|
279 |
+
hidden_states,
|
280 |
+
attention_mask=attention_mask,
|
281 |
+
encoder_hidden_states=encoder_hidden_states,
|
282 |
+
encoder_attention_mask=encoder_attention_mask,
|
283 |
+
timestep=timestep,
|
284 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
285 |
+
class_labels=class_labels,
|
286 |
+
)
|
287 |
+
|
288 |
+
# 3. Output
|
289 |
+
if self.is_input_continuous:
|
290 |
+
if not self.use_linear_projection:
|
291 |
+
hidden_states = hidden_states.reshape(
|
292 |
+
batch, height, width, inner_dim
|
293 |
+
).permute(0, 3, 1, 2).contiguous()
|
294 |
+
hidden_states = self.proj_out(hidden_states)
|
295 |
+
else:
|
296 |
+
hidden_states = self.proj_out(hidden_states)
|
297 |
+
hidden_states = hidden_states.reshape(
|
298 |
+
batch, height, width, inner_dim
|
299 |
+
).permute(0, 3, 1, 2).contiguous()
|
300 |
+
output = hidden_states + residual
|
301 |
+
|
302 |
+
elif self.is_input_vectorized:
|
303 |
+
hidden_states = self.norm_out(hidden_states)
|
304 |
+
logits = self.out(hidden_states)
|
305 |
+
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
|
306 |
+
logits = logits.permute(0, 2, 1)
|
307 |
+
|
308 |
+
# log(p(x_0))
|
309 |
+
output = F.log_softmax(logits.double(), dim=1).float()
|
310 |
+
|
311 |
+
elif self.is_input_patches:
|
312 |
+
# TODO: cleanup!
|
313 |
+
conditioning = self.transformer_blocks[0].norm1.emb(
|
314 |
+
timestep, class_labels, hidden_dtype=hidden_states.dtype
|
315 |
+
)
|
316 |
+
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
317 |
+
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
318 |
+
hidden_states = self.proj_out_2(hidden_states)
|
319 |
+
|
320 |
+
# unpatchify
|
321 |
+
height = width = int(hidden_states.shape[1] ** 0.5)
|
322 |
+
hidden_states = hidden_states.reshape(
|
323 |
+
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
|
324 |
+
)
|
325 |
+
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
326 |
+
output = hidden_states.reshape(
|
327 |
+
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
|
328 |
+
)
|
329 |
+
|
330 |
+
if not return_dict:
|
331 |
+
return (output,)
|
332 |
+
|
333 |
+
return Transformer2DModelOutput(sample=output)
|
diffusers/models/unet_2d.py
ADDED
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
|
20 |
+
from ..utils.configuration_utils import ConfigMixin, register_to_config
|
21 |
+
from ..utils.outputs import BaseOutput
|
22 |
+
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
|
23 |
+
from .modeling_utils import ModelMixin
|
24 |
+
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
|
25 |
+
|
26 |
+
|
27 |
+
@dataclass
|
28 |
+
class UNet2DOutput(BaseOutput):
|
29 |
+
"""
|
30 |
+
Args:
|
31 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
32 |
+
Hidden states output. Output of last layer of model.
|
33 |
+
"""
|
34 |
+
|
35 |
+
sample: torch.FloatTensor
|
36 |
+
|
37 |
+
|
38 |
+
class UNet2DModel(ModelMixin, ConfigMixin):
|
39 |
+
r"""
|
40 |
+
UNet2DModel is a 2D UNet model that takes in a noisy sample and a timestep and returns sample shaped output.
|
41 |
+
|
42 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
|
43 |
+
implements for all the model (such as downloading or saving, etc.)
|
44 |
+
|
45 |
+
Parameters:
|
46 |
+
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
47 |
+
Height and width of input/output sample.
|
48 |
+
in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
|
49 |
+
out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
|
50 |
+
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
51 |
+
time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
|
52 |
+
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
|
53 |
+
flip_sin_to_cos (`bool`, *optional*, defaults to :
|
54 |
+
obj:`True`): Whether to flip sin to cos for fourier time embedding.
|
55 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to :
|
56 |
+
obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block
|
57 |
+
types.
|
58 |
+
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
|
59 |
+
The mid block type. Choose from `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
|
60 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to :
|
61 |
+
obj:`("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): Tuple of upsample block types.
|
62 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to :
|
63 |
+
obj:`(224, 448, 672, 896)`): Tuple of block output channels.
|
64 |
+
layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
|
65 |
+
mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
|
66 |
+
downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
|
67 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
68 |
+
attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
|
69 |
+
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization.
|
70 |
+
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
|
71 |
+
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
72 |
+
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
|
73 |
+
class_embed_type (`str`, *optional*, defaults to None):
|
74 |
+
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
75 |
+
`"timestep"`, or `"identity"`.
|
76 |
+
num_class_embeds (`int`, *optional*, defaults to None):
|
77 |
+
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
78 |
+
class conditioning with `class_embed_type` equal to `None`.
|
79 |
+
"""
|
80 |
+
|
81 |
+
@register_to_config
|
82 |
+
def __init__(
|
83 |
+
self,
|
84 |
+
sample_size: Optional[Union[int, Tuple[int, int]]] = None,
|
85 |
+
in_channels: int = 3,
|
86 |
+
out_channels: int = 3,
|
87 |
+
center_input_sample: bool = False,
|
88 |
+
time_embedding_type: str = "positional",
|
89 |
+
freq_shift: int = 0,
|
90 |
+
flip_sin_to_cos: bool = True,
|
91 |
+
down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
|
92 |
+
up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
|
93 |
+
block_out_channels: Tuple[int] = (224, 448, 672, 896),
|
94 |
+
layers_per_block: int = 2,
|
95 |
+
mid_block_scale_factor: float = 1,
|
96 |
+
downsample_padding: int = 1,
|
97 |
+
act_fn: str = "silu",
|
98 |
+
attention_head_dim: Optional[int] = 8,
|
99 |
+
norm_num_groups: int = 32,
|
100 |
+
norm_eps: float = 1e-5,
|
101 |
+
resnet_time_scale_shift: str = "default",
|
102 |
+
add_attention: bool = True,
|
103 |
+
class_embed_type: Optional[str] = None,
|
104 |
+
num_class_embeds: Optional[int] = None,
|
105 |
+
):
|
106 |
+
super().__init__()
|
107 |
+
|
108 |
+
self.sample_size = sample_size
|
109 |
+
time_embed_dim = block_out_channels[0] * 4
|
110 |
+
|
111 |
+
# Check inputs
|
112 |
+
if len(down_block_types) != len(up_block_types):
|
113 |
+
raise ValueError(
|
114 |
+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
115 |
+
)
|
116 |
+
|
117 |
+
if len(block_out_channels) != len(down_block_types):
|
118 |
+
raise ValueError(
|
119 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
120 |
+
)
|
121 |
+
|
122 |
+
# input
|
123 |
+
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
|
124 |
+
|
125 |
+
# time
|
126 |
+
if time_embedding_type == "fourier":
|
127 |
+
self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
|
128 |
+
timestep_input_dim = 2 * block_out_channels[0]
|
129 |
+
elif time_embedding_type == "positional":
|
130 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
131 |
+
timestep_input_dim = block_out_channels[0]
|
132 |
+
|
133 |
+
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
134 |
+
|
135 |
+
# class embedding
|
136 |
+
if class_embed_type is None and num_class_embeds is not None:
|
137 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
138 |
+
elif class_embed_type == "timestep":
|
139 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
140 |
+
elif class_embed_type == "identity":
|
141 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
142 |
+
else:
|
143 |
+
self.class_embedding = None
|
144 |
+
|
145 |
+
self.down_blocks = nn.ModuleList([])
|
146 |
+
self.mid_block = None
|
147 |
+
self.up_blocks = nn.ModuleList([])
|
148 |
+
|
149 |
+
# down
|
150 |
+
output_channel = block_out_channels[0]
|
151 |
+
for i, down_block_type in enumerate(down_block_types):
|
152 |
+
input_channel = output_channel
|
153 |
+
output_channel = block_out_channels[i]
|
154 |
+
is_final_block = i == len(block_out_channels) - 1
|
155 |
+
|
156 |
+
down_block = get_down_block(
|
157 |
+
down_block_type,
|
158 |
+
num_layers=layers_per_block,
|
159 |
+
in_channels=input_channel,
|
160 |
+
out_channels=output_channel,
|
161 |
+
temb_channels=time_embed_dim,
|
162 |
+
add_downsample=not is_final_block,
|
163 |
+
resnet_eps=norm_eps,
|
164 |
+
resnet_act_fn=act_fn,
|
165 |
+
resnet_groups=norm_num_groups,
|
166 |
+
attn_num_head_channels=attention_head_dim,
|
167 |
+
downsample_padding=downsample_padding,
|
168 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
169 |
+
)
|
170 |
+
self.down_blocks.append(down_block)
|
171 |
+
|
172 |
+
# mid
|
173 |
+
self.mid_block = UNetMidBlock2D(
|
174 |
+
in_channels=block_out_channels[-1],
|
175 |
+
temb_channels=time_embed_dim,
|
176 |
+
resnet_eps=norm_eps,
|
177 |
+
resnet_act_fn=act_fn,
|
178 |
+
output_scale_factor=mid_block_scale_factor,
|
179 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
180 |
+
attn_num_head_channels=attention_head_dim,
|
181 |
+
resnet_groups=norm_num_groups,
|
182 |
+
add_attention=add_attention,
|
183 |
+
)
|
184 |
+
|
185 |
+
# up
|
186 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
187 |
+
output_channel = reversed_block_out_channels[0]
|
188 |
+
for i, up_block_type in enumerate(up_block_types):
|
189 |
+
prev_output_channel = output_channel
|
190 |
+
output_channel = reversed_block_out_channels[i]
|
191 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
192 |
+
|
193 |
+
is_final_block = i == len(block_out_channels) - 1
|
194 |
+
|
195 |
+
up_block = get_up_block(
|
196 |
+
up_block_type,
|
197 |
+
num_layers=layers_per_block + 1,
|
198 |
+
in_channels=input_channel,
|
199 |
+
out_channels=output_channel,
|
200 |
+
prev_output_channel=prev_output_channel,
|
201 |
+
temb_channels=time_embed_dim,
|
202 |
+
add_upsample=not is_final_block,
|
203 |
+
resnet_eps=norm_eps,
|
204 |
+
resnet_act_fn=act_fn,
|
205 |
+
resnet_groups=norm_num_groups,
|
206 |
+
attn_num_head_channels=attention_head_dim,
|
207 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
208 |
+
)
|
209 |
+
self.up_blocks.append(up_block)
|
210 |
+
prev_output_channel = output_channel
|
211 |
+
|
212 |
+
# out
|
213 |
+
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
|
214 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
|
215 |
+
self.conv_act = nn.SiLU()
|
216 |
+
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
217 |
+
|
218 |
+
def forward(
|
219 |
+
self,
|
220 |
+
sample: torch.FloatTensor,
|
221 |
+
timestep: Union[torch.Tensor, float, int],
|
222 |
+
class_labels: Optional[torch.Tensor] = None,
|
223 |
+
return_dict: bool = True,
|
224 |
+
) -> Union[UNet2DOutput, Tuple]:
|
225 |
+
r"""
|
226 |
+
Args:
|
227 |
+
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
228 |
+
timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps
|
229 |
+
class_labels (`torch.FloatTensor`, *optional*, defaults to `None`):
|
230 |
+
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
231 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
232 |
+
Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
|
233 |
+
|
234 |
+
Returns:
|
235 |
+
[`~models.unet_2d.UNet2DOutput`] or `tuple`: [`~models.unet_2d.UNet2DOutput`] if `return_dict` is True,
|
236 |
+
otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
|
237 |
+
"""
|
238 |
+
# 0. center input if necessary
|
239 |
+
if self.config.center_input_sample:
|
240 |
+
sample = 2 * sample - 1.0
|
241 |
+
|
242 |
+
# 1. time
|
243 |
+
timesteps = timestep
|
244 |
+
if not torch.is_tensor(timesteps):
|
245 |
+
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
246 |
+
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
247 |
+
timesteps = timesteps[None].to(sample.device)
|
248 |
+
|
249 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
250 |
+
timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
|
251 |
+
|
252 |
+
t_emb = self.time_proj(timesteps)
|
253 |
+
|
254 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
255 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
256 |
+
# there might be better ways to encapsulate this.
|
257 |
+
t_emb = t_emb.to(dtype=self.dtype)
|
258 |
+
emb = self.time_embedding(t_emb)
|
259 |
+
|
260 |
+
if self.class_embedding is not None:
|
261 |
+
if class_labels is None:
|
262 |
+
raise ValueError("class_labels should be provided when doing class conditioning")
|
263 |
+
|
264 |
+
if self.config.class_embed_type == "timestep":
|
265 |
+
class_labels = self.time_proj(class_labels)
|
266 |
+
|
267 |
+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
268 |
+
emb = emb + class_emb
|
269 |
+
|
270 |
+
# 2. pre-process
|
271 |
+
skip_sample = sample
|
272 |
+
sample = self.conv_in(sample)
|
273 |
+
|
274 |
+
# 3. down
|
275 |
+
down_block_res_samples = (sample,)
|
276 |
+
for downsample_block in self.down_blocks:
|
277 |
+
if hasattr(downsample_block, "skip_conv"):
|
278 |
+
sample, res_samples, skip_sample = downsample_block(
|
279 |
+
hidden_states=sample, temb=emb, skip_sample=skip_sample
|
280 |
+
)
|
281 |
+
else:
|
282 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
283 |
+
|
284 |
+
down_block_res_samples += res_samples
|
285 |
+
|
286 |
+
# 4. mid
|
287 |
+
sample = self.mid_block(sample, emb)
|
288 |
+
|
289 |
+
# 5. up
|
290 |
+
skip_sample = None
|
291 |
+
for upsample_block in self.up_blocks:
|
292 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
293 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
294 |
+
|
295 |
+
if hasattr(upsample_block, "skip_conv"):
|
296 |
+
sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
|
297 |
+
else:
|
298 |
+
sample = upsample_block(sample, res_samples, emb)
|
299 |
+
|
300 |
+
# 6. post-process
|
301 |
+
sample = self.conv_norm_out(sample)
|
302 |
+
sample = self.conv_act(sample)
|
303 |
+
sample = self.conv_out(sample)
|
304 |
+
|
305 |
+
if skip_sample is not None:
|
306 |
+
sample += skip_sample
|
307 |
+
|
308 |
+
if self.config.time_embedding_type == "fourier":
|
309 |
+
timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
|
310 |
+
sample = sample / timesteps
|
311 |
+
|
312 |
+
if not return_dict:
|
313 |
+
return (sample,)
|
314 |
+
|
315 |
+
return UNet2DOutput(sample=sample)
|
diffusers/models/unet_2d_blocks.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
diffusers/models/unet_2d_condition.py
ADDED
@@ -0,0 +1,907 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from dataclasses import dataclass
|
16 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.utils.checkpoint
|
21 |
+
|
22 |
+
from ..utils.configuration_utils import ConfigMixin, register_to_config
|
23 |
+
from ..utils.outputs import BaseOutput
|
24 |
+
from .loaders import UNet2DConditionLoadersMixin
|
25 |
+
from .activations import get_activation
|
26 |
+
from .attention_processor import AttentionProcessor, AttnProcessor
|
27 |
+
from .embeddings import (
|
28 |
+
GaussianFourierProjection,
|
29 |
+
TextImageProjection,
|
30 |
+
TextImageTimeEmbedding,
|
31 |
+
TextTimeEmbedding,
|
32 |
+
TimestepEmbedding,
|
33 |
+
Timesteps,
|
34 |
+
)
|
35 |
+
from .modeling_utils import ModelMixin
|
36 |
+
from .unet_2d_blocks import (
|
37 |
+
CrossAttnDownBlock2D,
|
38 |
+
CrossAttnUpBlock2D,
|
39 |
+
DownBlock2D,
|
40 |
+
UNetMidBlock2DCrossAttn,
|
41 |
+
UNetMidBlock2DSimpleCrossAttn,
|
42 |
+
UpBlock2D,
|
43 |
+
get_down_block,
|
44 |
+
get_up_block,
|
45 |
+
)
|
46 |
+
from ..utils.logging import get_logger
|
47 |
+
|
48 |
+
logger = get_logger(__name__) # pylint: disable=invalid-name
|
49 |
+
|
50 |
+
|
51 |
+
@dataclass
|
52 |
+
class UNet2DConditionOutput(BaseOutput):
|
53 |
+
"""
|
54 |
+
Args:
|
55 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
56 |
+
Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
|
57 |
+
"""
|
58 |
+
|
59 |
+
sample: torch.FloatTensor
|
60 |
+
|
61 |
+
|
62 |
+
class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
63 |
+
r"""
|
64 |
+
UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
|
65 |
+
and returns sample shaped output.
|
66 |
+
|
67 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
|
68 |
+
implements for all the models (such as downloading or saving, etc.)
|
69 |
+
|
70 |
+
Parameters:
|
71 |
+
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
72 |
+
Height and width of input/output sample.
|
73 |
+
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
|
74 |
+
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
|
75 |
+
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
76 |
+
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
77 |
+
Whether to flip the sin to cos in the time embedding.
|
78 |
+
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
79 |
+
down_block_types (`Tuple[str]`, *optional*,
|
80 |
+
defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
81 |
+
The tuple of downsample blocks to use.
|
82 |
+
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
83 |
+
The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`, will skip the
|
84 |
+
mid block layer if `None`.
|
85 |
+
up_block_types (`Tuple[str]`, *optional*,
|
86 |
+
defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
|
87 |
+
The tuple of upsample blocks to use.
|
88 |
+
only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
|
89 |
+
Whether to include self-attention in the basic transformer blocks, see
|
90 |
+
[`~models.attention.BasicTransformerBlock`].
|
91 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
92 |
+
The tuple of output channels for each block.
|
93 |
+
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
94 |
+
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
|
95 |
+
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
|
96 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
97 |
+
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
|
98 |
+
If `None`, it will skip the normalization and activation layers in post-processing
|
99 |
+
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
100 |
+
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
|
101 |
+
The dimension of the cross attention features.
|
102 |
+
encoder_hid_dim (`int`, *optional*, defaults to None):
|
103 |
+
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
104 |
+
dimension to `cross_attention_dim`.
|
105 |
+
encoder_hid_dim_type (`str`, *optional*, defaults to None):
|
106 |
+
If given, the `encoder_hidden_states` and potentially other embeddings will be down-projected to text
|
107 |
+
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
108 |
+
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
109 |
+
num_attention_heads (`int`, *optional*):
|
110 |
+
The number of attention heads. If not defined, defaults to `attention_head_dim`
|
111 |
+
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
112 |
+
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
|
113 |
+
class_embed_type (`str`, *optional*, defaults to None):
|
114 |
+
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
115 |
+
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
116 |
+
addition_embed_type (`str`, *optional*, defaults to None):
|
117 |
+
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
118 |
+
"text". "text" will use the `TextTimeEmbedding` layer.
|
119 |
+
num_class_embeds (`int`, *optional*, defaults to None):
|
120 |
+
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
121 |
+
class conditioning with `class_embed_type` equal to `None`.
|
122 |
+
time_embedding_type (`str`, *optional*, default to `positional`):
|
123 |
+
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
|
124 |
+
time_embedding_dim (`int`, *optional*, default to `None`):
|
125 |
+
An optional override for the dimension of the projected time embedding.
|
126 |
+
time_embedding_act_fn (`str`, *optional*, default to `None`):
|
127 |
+
Optional activation function to use on the time embeddings only one time before they as passed to the rest
|
128 |
+
of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`.
|
129 |
+
timestep_post_act (`str, *optional*, default to `None`):
|
130 |
+
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
|
131 |
+
time_cond_proj_dim (`int`, *optional*, default to `None`):
|
132 |
+
The dimension of `cond_proj` layer in timestep embedding.
|
133 |
+
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
|
134 |
+
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
|
135 |
+
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
|
136 |
+
using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
|
137 |
+
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
|
138 |
+
embeddings with the class embeddings.
|
139 |
+
mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
|
140 |
+
Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
|
141 |
+
`only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is None, the
|
142 |
+
`only_cross_attention` value will be used as the value for `mid_block_only_cross_attention`. Else, it will
|
143 |
+
default to `False`.
|
144 |
+
"""
|
145 |
+
|
146 |
+
_supports_gradient_checkpointing = True
|
147 |
+
|
148 |
+
@register_to_config
|
149 |
+
def __init__(
|
150 |
+
self,
|
151 |
+
sample_size: Optional[int] = None,
|
152 |
+
in_channels: int = 4,
|
153 |
+
out_channels: int = 4,
|
154 |
+
center_input_sample: bool = False,
|
155 |
+
flip_sin_to_cos: bool = True,
|
156 |
+
freq_shift: int = 0,
|
157 |
+
down_block_types: Tuple[str] = (
|
158 |
+
"CrossAttnDownBlock2D",
|
159 |
+
"CrossAttnDownBlock2D",
|
160 |
+
"CrossAttnDownBlock2D",
|
161 |
+
"DownBlock2D",
|
162 |
+
),
|
163 |
+
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
164 |
+
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
165 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
166 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
167 |
+
layers_per_block: Union[int, Tuple[int]] = 2,
|
168 |
+
downsample_padding: int = 1,
|
169 |
+
mid_block_scale_factor: float = 1,
|
170 |
+
act_fn: str = "silu",
|
171 |
+
norm_num_groups: Optional[int] = 32,
|
172 |
+
norm_eps: float = 1e-5,
|
173 |
+
cross_attention_dim: Union[int, Tuple[int]] = 1280,
|
174 |
+
encoder_hid_dim: Optional[int] = None,
|
175 |
+
encoder_hid_dim_type: Optional[str] = None,
|
176 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
177 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
178 |
+
dual_cross_attention: bool = False,
|
179 |
+
use_linear_projection: bool = False,
|
180 |
+
class_embed_type: Optional[str] = None,
|
181 |
+
addition_embed_type: Optional[str] = None,
|
182 |
+
num_class_embeds: Optional[int] = None,
|
183 |
+
upcast_attention: bool = False,
|
184 |
+
resnet_time_scale_shift: str = "default",
|
185 |
+
resnet_skip_time_act: bool = False,
|
186 |
+
resnet_out_scale_factor: int = 1.0,
|
187 |
+
time_embedding_type: str = "positional",
|
188 |
+
time_embedding_dim: Optional[int] = None,
|
189 |
+
time_embedding_act_fn: Optional[str] = None,
|
190 |
+
timestep_post_act: Optional[str] = None,
|
191 |
+
time_cond_proj_dim: Optional[int] = None,
|
192 |
+
conv_in_kernel: int = 3,
|
193 |
+
conv_out_kernel: int = 3,
|
194 |
+
projection_class_embeddings_input_dim: Optional[int] = None,
|
195 |
+
class_embeddings_concat: bool = False,
|
196 |
+
mid_block_only_cross_attention: Optional[bool] = None,
|
197 |
+
cross_attention_norm: Optional[str] = None,
|
198 |
+
addition_embed_type_num_heads=64,
|
199 |
+
):
|
200 |
+
super().__init__()
|
201 |
+
|
202 |
+
self.sample_size = sample_size
|
203 |
+
|
204 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
205 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
206 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
207 |
+
# when this library was created. The incorrect naming was only discovered much later in
|
208 |
+
# https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
209 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
210 |
+
# which is why we correct for the naming here.
|
211 |
+
num_attention_heads = num_attention_heads or attention_head_dim
|
212 |
+
|
213 |
+
# Check inputs
|
214 |
+
if len(down_block_types) != len(up_block_types):
|
215 |
+
raise ValueError(
|
216 |
+
"Must provide the same number of `down_block_types` as `up_block_types`. "
|
217 |
+
f"`down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
218 |
+
)
|
219 |
+
|
220 |
+
if len(block_out_channels) != len(down_block_types):
|
221 |
+
raise ValueError(
|
222 |
+
"Must provide the same number of `block_out_channels` as `down_block_types`. "
|
223 |
+
f"`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
224 |
+
)
|
225 |
+
|
226 |
+
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
227 |
+
raise ValueError(
|
228 |
+
"Must provide the same number of `only_cross_attention` as `down_block_types`. "
|
229 |
+
f"`only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
230 |
+
)
|
231 |
+
|
232 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
233 |
+
raise ValueError(
|
234 |
+
"Must provide the same number of `num_attention_heads` as `down_block_types`. "
|
235 |
+
f"`num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
236 |
+
)
|
237 |
+
|
238 |
+
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
|
239 |
+
raise ValueError(
|
240 |
+
"Must provide the same number of `attention_head_dim` as `down_block_types`. "
|
241 |
+
f"`attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
|
242 |
+
)
|
243 |
+
|
244 |
+
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
|
245 |
+
raise ValueError(
|
246 |
+
"Must provide the same number of `cross_attention_dim` as `down_block_types`. "
|
247 |
+
f"`cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
|
248 |
+
)
|
249 |
+
|
250 |
+
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
|
251 |
+
raise ValueError(
|
252 |
+
f"Must provide the same number of `layers_per_block` as `down_block_types`. "
|
253 |
+
f"`layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
|
254 |
+
)
|
255 |
+
|
256 |
+
# input
|
257 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
258 |
+
self.conv_in = nn.Conv2d(
|
259 |
+
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
260 |
+
)
|
261 |
+
|
262 |
+
# time
|
263 |
+
if time_embedding_type == "fourier":
|
264 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
|
265 |
+
if time_embed_dim % 2 != 0:
|
266 |
+
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
|
267 |
+
self.time_proj = GaussianFourierProjection(
|
268 |
+
time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
|
269 |
+
)
|
270 |
+
timestep_input_dim = time_embed_dim
|
271 |
+
elif time_embedding_type == "positional":
|
272 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
|
273 |
+
|
274 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
275 |
+
timestep_input_dim = block_out_channels[0]
|
276 |
+
else:
|
277 |
+
raise ValueError(
|
278 |
+
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
|
279 |
+
)
|
280 |
+
|
281 |
+
self.time_embedding = TimestepEmbedding(
|
282 |
+
timestep_input_dim,
|
283 |
+
time_embed_dim,
|
284 |
+
act_fn=act_fn,
|
285 |
+
post_act_fn=timestep_post_act,
|
286 |
+
cond_proj_dim=time_cond_proj_dim,
|
287 |
+
)
|
288 |
+
|
289 |
+
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
290 |
+
encoder_hid_dim_type = "text_proj"
|
291 |
+
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
292 |
+
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
293 |
+
|
294 |
+
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
295 |
+
raise ValueError(
|
296 |
+
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
297 |
+
)
|
298 |
+
|
299 |
+
if encoder_hid_dim_type == "text_proj":
|
300 |
+
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
301 |
+
elif encoder_hid_dim_type == "text_image_proj":
|
302 |
+
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
303 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently
|
304 |
+
# only use case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
|
305 |
+
self.encoder_hid_proj = TextImageProjection(
|
306 |
+
text_embed_dim=encoder_hid_dim,
|
307 |
+
image_embed_dim=cross_attention_dim,
|
308 |
+
cross_attention_dim=cross_attention_dim,
|
309 |
+
)
|
310 |
+
|
311 |
+
elif encoder_hid_dim_type is not None:
|
312 |
+
raise ValueError(
|
313 |
+
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
314 |
+
)
|
315 |
+
else:
|
316 |
+
self.encoder_hid_proj = None
|
317 |
+
|
318 |
+
# class embedding
|
319 |
+
if class_embed_type is None and num_class_embeds is not None:
|
320 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
321 |
+
elif class_embed_type == "timestep":
|
322 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
|
323 |
+
elif class_embed_type == "identity":
|
324 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
325 |
+
elif class_embed_type == "projection":
|
326 |
+
if projection_class_embeddings_input_dim is None:
|
327 |
+
raise ValueError(
|
328 |
+
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
329 |
+
)
|
330 |
+
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
331 |
+
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
332 |
+
# 2. it projects from an arbitrary input dimension.
|
333 |
+
#
|
334 |
+
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
335 |
+
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
336 |
+
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
337 |
+
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
338 |
+
elif class_embed_type == "simple_projection":
|
339 |
+
if projection_class_embeddings_input_dim is None:
|
340 |
+
raise ValueError(
|
341 |
+
"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
|
342 |
+
)
|
343 |
+
self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
|
344 |
+
else:
|
345 |
+
self.class_embedding = None
|
346 |
+
|
347 |
+
if addition_embed_type == "text":
|
348 |
+
if encoder_hid_dim is not None:
|
349 |
+
text_time_embedding_from_dim = encoder_hid_dim
|
350 |
+
else:
|
351 |
+
text_time_embedding_from_dim = cross_attention_dim
|
352 |
+
|
353 |
+
self.add_embedding = TextTimeEmbedding(
|
354 |
+
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
355 |
+
)
|
356 |
+
elif addition_embed_type == "text_image":
|
357 |
+
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__
|
358 |
+
# too much, they are set to `cross_attention_dim` here as this is exactly the required dimension for the
|
359 |
+
# currently only use case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
|
360 |
+
self.add_embedding = TextImageTimeEmbedding(
|
361 |
+
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
362 |
+
)
|
363 |
+
elif addition_embed_type is not None:
|
364 |
+
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
365 |
+
|
366 |
+
if time_embedding_act_fn is None:
|
367 |
+
self.time_embed_act = None
|
368 |
+
else:
|
369 |
+
self.time_embed_act = get_activation(time_embedding_act_fn)
|
370 |
+
|
371 |
+
self.down_blocks = nn.ModuleList([])
|
372 |
+
self.up_blocks = nn.ModuleList([])
|
373 |
+
|
374 |
+
if isinstance(only_cross_attention, bool):
|
375 |
+
if mid_block_only_cross_attention is None:
|
376 |
+
mid_block_only_cross_attention = only_cross_attention
|
377 |
+
|
378 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
379 |
+
|
380 |
+
if mid_block_only_cross_attention is None:
|
381 |
+
mid_block_only_cross_attention = False
|
382 |
+
|
383 |
+
if isinstance(num_attention_heads, int):
|
384 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
385 |
+
|
386 |
+
if isinstance(attention_head_dim, int):
|
387 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
388 |
+
|
389 |
+
if isinstance(cross_attention_dim, int):
|
390 |
+
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
|
391 |
+
|
392 |
+
if isinstance(layers_per_block, int):
|
393 |
+
layers_per_block = [layers_per_block] * len(down_block_types)
|
394 |
+
|
395 |
+
if class_embeddings_concat:
|
396 |
+
# The time embeddings are concatenated with the class embeddings. The dimension of the
|
397 |
+
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
|
398 |
+
# regular time embeddings
|
399 |
+
blocks_time_embed_dim = time_embed_dim * 2
|
400 |
+
else:
|
401 |
+
blocks_time_embed_dim = time_embed_dim
|
402 |
+
|
403 |
+
# down
|
404 |
+
output_channel = block_out_channels[0]
|
405 |
+
for i, down_block_type in enumerate(down_block_types):
|
406 |
+
input_channel = output_channel
|
407 |
+
output_channel = block_out_channels[i]
|
408 |
+
is_final_block = i == len(block_out_channels) - 1
|
409 |
+
|
410 |
+
down_block = get_down_block(
|
411 |
+
down_block_type,
|
412 |
+
num_layers=layers_per_block[i],
|
413 |
+
in_channels=input_channel,
|
414 |
+
out_channels=output_channel,
|
415 |
+
temb_channels=blocks_time_embed_dim,
|
416 |
+
add_downsample=not is_final_block,
|
417 |
+
resnet_eps=norm_eps,
|
418 |
+
resnet_act_fn=act_fn,
|
419 |
+
resnet_groups=norm_num_groups,
|
420 |
+
cross_attention_dim=cross_attention_dim[i],
|
421 |
+
num_attention_heads=num_attention_heads[i],
|
422 |
+
downsample_padding=downsample_padding,
|
423 |
+
dual_cross_attention=dual_cross_attention,
|
424 |
+
use_linear_projection=use_linear_projection,
|
425 |
+
only_cross_attention=only_cross_attention[i],
|
426 |
+
upcast_attention=upcast_attention,
|
427 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
428 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
429 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
430 |
+
cross_attention_norm=cross_attention_norm,
|
431 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
432 |
+
)
|
433 |
+
self.down_blocks.append(down_block)
|
434 |
+
|
435 |
+
# mid
|
436 |
+
if mid_block_type == "UNetMidBlock2DCrossAttn":
|
437 |
+
self.mid_block = UNetMidBlock2DCrossAttn(
|
438 |
+
in_channels=block_out_channels[-1],
|
439 |
+
temb_channels=blocks_time_embed_dim,
|
440 |
+
resnet_eps=norm_eps,
|
441 |
+
resnet_act_fn=act_fn,
|
442 |
+
output_scale_factor=mid_block_scale_factor,
|
443 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
444 |
+
cross_attention_dim=cross_attention_dim[-1],
|
445 |
+
num_attention_heads=num_attention_heads[-1],
|
446 |
+
resnet_groups=norm_num_groups,
|
447 |
+
dual_cross_attention=dual_cross_attention,
|
448 |
+
use_linear_projection=use_linear_projection,
|
449 |
+
upcast_attention=upcast_attention,
|
450 |
+
)
|
451 |
+
elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
|
452 |
+
self.mid_block = UNetMidBlock2DSimpleCrossAttn(
|
453 |
+
in_channels=block_out_channels[-1],
|
454 |
+
temb_channels=blocks_time_embed_dim,
|
455 |
+
resnet_eps=norm_eps,
|
456 |
+
resnet_act_fn=act_fn,
|
457 |
+
output_scale_factor=mid_block_scale_factor,
|
458 |
+
cross_attention_dim=cross_attention_dim[-1],
|
459 |
+
attention_head_dim=attention_head_dim[-1],
|
460 |
+
resnet_groups=norm_num_groups,
|
461 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
462 |
+
skip_time_act=resnet_skip_time_act,
|
463 |
+
only_cross_attention=mid_block_only_cross_attention,
|
464 |
+
cross_attention_norm=cross_attention_norm,
|
465 |
+
)
|
466 |
+
elif mid_block_type is None:
|
467 |
+
self.mid_block = None
|
468 |
+
else:
|
469 |
+
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
470 |
+
|
471 |
+
# count how many layers upsample the images
|
472 |
+
self.num_upsamplers = 0
|
473 |
+
|
474 |
+
# up
|
475 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
476 |
+
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
477 |
+
reversed_layers_per_block = list(reversed(layers_per_block))
|
478 |
+
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
|
479 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
480 |
+
|
481 |
+
output_channel = reversed_block_out_channels[0]
|
482 |
+
for i, up_block_type in enumerate(up_block_types):
|
483 |
+
is_final_block = i == len(block_out_channels) - 1
|
484 |
+
|
485 |
+
prev_output_channel = output_channel
|
486 |
+
output_channel = reversed_block_out_channels[i]
|
487 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
488 |
+
|
489 |
+
# add upsample block for all BUT final layer
|
490 |
+
if not is_final_block:
|
491 |
+
add_upsample = True
|
492 |
+
self.num_upsamplers += 1
|
493 |
+
else:
|
494 |
+
add_upsample = False
|
495 |
+
|
496 |
+
up_block = get_up_block(
|
497 |
+
up_block_type,
|
498 |
+
num_layers=reversed_layers_per_block[i] + 1,
|
499 |
+
in_channels=input_channel,
|
500 |
+
out_channels=output_channel,
|
501 |
+
prev_output_channel=prev_output_channel,
|
502 |
+
temb_channels=blocks_time_embed_dim,
|
503 |
+
add_upsample=add_upsample,
|
504 |
+
resnet_eps=norm_eps,
|
505 |
+
resnet_act_fn=act_fn,
|
506 |
+
resnet_groups=norm_num_groups,
|
507 |
+
cross_attention_dim=reversed_cross_attention_dim[i],
|
508 |
+
num_attention_heads=reversed_num_attention_heads[i],
|
509 |
+
dual_cross_attention=dual_cross_attention,
|
510 |
+
use_linear_projection=use_linear_projection,
|
511 |
+
only_cross_attention=only_cross_attention[i],
|
512 |
+
upcast_attention=upcast_attention,
|
513 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
514 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
515 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
516 |
+
cross_attention_norm=cross_attention_norm,
|
517 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
518 |
+
)
|
519 |
+
self.up_blocks.append(up_block)
|
520 |
+
prev_output_channel = output_channel
|
521 |
+
|
522 |
+
# out
|
523 |
+
if norm_num_groups is not None:
|
524 |
+
self.conv_norm_out = nn.GroupNorm(
|
525 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
526 |
+
)
|
527 |
+
|
528 |
+
self.conv_act = get_activation(act_fn)
|
529 |
+
|
530 |
+
else:
|
531 |
+
self.conv_norm_out = None
|
532 |
+
self.conv_act = None
|
533 |
+
|
534 |
+
conv_out_padding = (conv_out_kernel - 1) // 2
|
535 |
+
self.conv_out = nn.Conv2d(
|
536 |
+
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
|
537 |
+
)
|
538 |
+
|
539 |
+
@property
|
540 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
541 |
+
r"""
|
542 |
+
Returns:
|
543 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
544 |
+
indexed by its weight name.
|
545 |
+
"""
|
546 |
+
# set recursively
|
547 |
+
processors = {}
|
548 |
+
|
549 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
550 |
+
if hasattr(module, "set_processor"):
|
551 |
+
processors[f"{name}.processor"] = module.processor
|
552 |
+
|
553 |
+
for sub_name, child in module.named_children():
|
554 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
555 |
+
|
556 |
+
return processors
|
557 |
+
|
558 |
+
for name, module in self.named_children():
|
559 |
+
fn_recursive_add_processors(name, module, processors)
|
560 |
+
|
561 |
+
return processors
|
562 |
+
|
563 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
564 |
+
r"""
|
565 |
+
Parameters:
|
566 |
+
`processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
|
567 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
568 |
+
of **all** `Attention` layers.
|
569 |
+
In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor.
|
570 |
+
This is strongly recommended when setting trainable attention processors.
|
571 |
+
"""
|
572 |
+
count = len(self.attn_processors.keys())
|
573 |
+
|
574 |
+
if isinstance(processor, dict) and len(processor) != count:
|
575 |
+
raise ValueError(
|
576 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
577 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
578 |
+
)
|
579 |
+
|
580 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
581 |
+
if hasattr(module, "set_processor"):
|
582 |
+
if not isinstance(processor, dict):
|
583 |
+
module.set_processor(processor)
|
584 |
+
else:
|
585 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
586 |
+
|
587 |
+
for sub_name, child in module.named_children():
|
588 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
589 |
+
|
590 |
+
for name, module in self.named_children():
|
591 |
+
fn_recursive_attn_processor(name, module, processor)
|
592 |
+
|
593 |
+
def set_default_attn_processor(self):
|
594 |
+
"""
|
595 |
+
Disables custom attention processors and sets the default attention implementation.
|
596 |
+
"""
|
597 |
+
self.set_attn_processor(AttnProcessor())
|
598 |
+
|
599 |
+
def set_attention_slice(self, slice_size):
|
600 |
+
r"""
|
601 |
+
Enable sliced attention computation.
|
602 |
+
|
603 |
+
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
604 |
+
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
605 |
+
|
606 |
+
Args:
|
607 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
608 |
+
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
609 |
+
`"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
|
610 |
+
provided, uses as many slices as `num_attention_heads // slice_size`. In this case,
|
611 |
+
`num_attention_heads` must be a multiple of `slice_size`.
|
612 |
+
"""
|
613 |
+
sliceable_head_dims = []
|
614 |
+
|
615 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
616 |
+
if hasattr(module, "set_attention_slice"):
|
617 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
618 |
+
|
619 |
+
for child in module.children():
|
620 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
621 |
+
|
622 |
+
# retrieve number of attention layers
|
623 |
+
for module in self.children():
|
624 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
625 |
+
|
626 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
627 |
+
|
628 |
+
if slice_size == "auto":
|
629 |
+
# half the attention head size is usually a good trade-off between
|
630 |
+
# speed and memory
|
631 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
632 |
+
elif slice_size == "max":
|
633 |
+
# make smallest slice possible
|
634 |
+
slice_size = num_sliceable_layers * [1]
|
635 |
+
|
636 |
+
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
637 |
+
|
638 |
+
if len(slice_size) != len(sliceable_head_dims):
|
639 |
+
raise ValueError(
|
640 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
641 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
642 |
+
)
|
643 |
+
|
644 |
+
for i in range(len(slice_size)):
|
645 |
+
size = slice_size[i]
|
646 |
+
dim = sliceable_head_dims[i]
|
647 |
+
if size is not None and size > dim:
|
648 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
649 |
+
|
650 |
+
# Recursively walk through all the children.
|
651 |
+
# Any children which exposes the set_attention_slice method
|
652 |
+
# gets the message
|
653 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
654 |
+
if hasattr(module, "set_attention_slice"):
|
655 |
+
module.set_attention_slice(slice_size.pop())
|
656 |
+
|
657 |
+
for child in module.children():
|
658 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
659 |
+
|
660 |
+
reversed_slice_size = list(reversed(slice_size))
|
661 |
+
for module in self.children():
|
662 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
663 |
+
|
664 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
665 |
+
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
|
666 |
+
module.gradient_checkpointing = value
|
667 |
+
|
668 |
+
def forward(
|
669 |
+
self,
|
670 |
+
sample: torch.FloatTensor,
|
671 |
+
timestep: Union[torch.Tensor, float, int],
|
672 |
+
encoder_hidden_states: torch.Tensor,
|
673 |
+
class_labels: Optional[torch.Tensor] = None,
|
674 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
675 |
+
attention_mask: Optional[torch.Tensor] = None,
|
676 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
677 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
678 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
679 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
680 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
681 |
+
return_dict: bool = True,
|
682 |
+
**kwargs
|
683 |
+
) -> Union[UNet2DConditionOutput, Tuple]:
|
684 |
+
r"""
|
685 |
+
Args:
|
686 |
+
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
687 |
+
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
688 |
+
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
689 |
+
encoder_attention_mask (`torch.Tensor`):
|
690 |
+
(batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep, False =
|
691 |
+
discard. Mask will be converted into a bias, which adds large negative values to attention scores
|
692 |
+
corresponding to "discard" tokens.
|
693 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
694 |
+
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
695 |
+
cross_attention_kwargs (`dict`, *optional*):
|
696 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
697 |
+
`self.processor` in [diffusers.cross_attention]
|
698 |
+
(https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
699 |
+
added_cond_kwargs (`dict`, *optional*):
|
700 |
+
A kwargs dictionary that if specified includes additonal conditions that can be used for additonal time
|
701 |
+
embeddings or encoder hidden states projections. See the configurations `encoder_hid_dim_type` and
|
702 |
+
`addition_embed_type` for more information.
|
703 |
+
|
704 |
+
Returns:
|
705 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
706 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
707 |
+
returning a tuple, the first element is the sample tensor.
|
708 |
+
"""
|
709 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
710 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
711 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
712 |
+
# on the fly if necessary.
|
713 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
714 |
+
|
715 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
716 |
+
forward_upsample_size = False
|
717 |
+
upsample_size = None
|
718 |
+
|
719 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
720 |
+
logger.info("Forward upsample size to force interpolation output size.")
|
721 |
+
forward_upsample_size = True
|
722 |
+
|
723 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
|
724 |
+
# expects mask of shape:
|
725 |
+
# [batch, key_tokens]
|
726 |
+
# adds singleton query_tokens dimension:
|
727 |
+
# [batch, 1, key_tokens]
|
728 |
+
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
729 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
730 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
731 |
+
if attention_mask is not None:
|
732 |
+
# assume that mask is expressed as:
|
733 |
+
# (1 = keep, 0 = discard)
|
734 |
+
# convert mask into a bias that can be added to attention scores:
|
735 |
+
# (keep = +0, discard = -10000.0)
|
736 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
737 |
+
attention_mask = attention_mask.unsqueeze(1)
|
738 |
+
|
739 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
740 |
+
if encoder_attention_mask is not None:
|
741 |
+
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
|
742 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
743 |
+
|
744 |
+
# 0. center input if necessary
|
745 |
+
if self.config.center_input_sample:
|
746 |
+
sample = 2 * sample - 1.0
|
747 |
+
|
748 |
+
# 1. time
|
749 |
+
timesteps = timestep
|
750 |
+
if not torch.is_tensor(timesteps):
|
751 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
752 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
753 |
+
is_mps = sample.device.type == "mps"
|
754 |
+
if isinstance(timestep, float):
|
755 |
+
dtype = torch.float32 if is_mps else torch.float64
|
756 |
+
else:
|
757 |
+
dtype = torch.int32 if is_mps else torch.int64
|
758 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
759 |
+
elif len(timesteps.shape) == 0:
|
760 |
+
timesteps = timesteps[None].to(sample.device)
|
761 |
+
|
762 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
763 |
+
timesteps = timesteps.expand(sample.shape[0])
|
764 |
+
|
765 |
+
t_emb = self.time_proj(timesteps)
|
766 |
+
|
767 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
768 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
769 |
+
# there might be better ways to encapsulate this.
|
770 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
771 |
+
|
772 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
773 |
+
|
774 |
+
if self.class_embedding is not None:
|
775 |
+
if class_labels is None:
|
776 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
777 |
+
|
778 |
+
if self.config.class_embed_type == "timestep":
|
779 |
+
class_labels = self.time_proj(class_labels)
|
780 |
+
|
781 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
782 |
+
# there might be better ways to encapsulate this.
|
783 |
+
class_labels = class_labels.to(dtype=sample.dtype)
|
784 |
+
|
785 |
+
class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
|
786 |
+
|
787 |
+
if self.config.class_embeddings_concat:
|
788 |
+
emb = torch.cat([emb, class_emb], dim=-1)
|
789 |
+
else:
|
790 |
+
emb = emb + class_emb
|
791 |
+
|
792 |
+
if self.config.addition_embed_type == "text":
|
793 |
+
aug_emb = self.add_embedding(encoder_hidden_states)
|
794 |
+
emb = emb + aug_emb
|
795 |
+
elif self.config.addition_embed_type == "text_image":
|
796 |
+
# Kadinsky 2.1 - style
|
797 |
+
if "image_embeds" not in added_cond_kwargs:
|
798 |
+
raise ValueError(
|
799 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' "
|
800 |
+
"which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
801 |
+
)
|
802 |
+
|
803 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
804 |
+
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
|
805 |
+
|
806 |
+
aug_emb = self.add_embedding(text_embs, image_embs)
|
807 |
+
emb = emb + aug_emb
|
808 |
+
|
809 |
+
if self.time_embed_act is not None:
|
810 |
+
emb = self.time_embed_act(emb)
|
811 |
+
|
812 |
+
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
|
813 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
|
814 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
|
815 |
+
# Kadinsky 2.1 - style
|
816 |
+
if "image_embeds" not in added_cond_kwargs:
|
817 |
+
raise ValueError(
|
818 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' "
|
819 |
+
"which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
820 |
+
)
|
821 |
+
|
822 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
823 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
|
824 |
+
|
825 |
+
# 2. pre-process
|
826 |
+
sample = self.conv_in(sample)
|
827 |
+
|
828 |
+
# 3. down
|
829 |
+
down_block_res_samples = (sample,)
|
830 |
+
for downsample_block in self.down_blocks:
|
831 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
832 |
+
sample, res_samples = downsample_block(
|
833 |
+
hidden_states=sample,
|
834 |
+
temb=emb,
|
835 |
+
encoder_hidden_states=encoder_hidden_states,
|
836 |
+
attention_mask=attention_mask,
|
837 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
838 |
+
encoder_attention_mask=encoder_attention_mask,
|
839 |
+
)
|
840 |
+
else:
|
841 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
842 |
+
|
843 |
+
down_block_res_samples += res_samples
|
844 |
+
|
845 |
+
if down_block_additional_residuals is not None:
|
846 |
+
new_down_block_res_samples = ()
|
847 |
+
|
848 |
+
for down_block_res_sample, down_block_additional_residual in zip(
|
849 |
+
down_block_res_samples, down_block_additional_residuals
|
850 |
+
):
|
851 |
+
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
852 |
+
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
|
853 |
+
|
854 |
+
down_block_res_samples = new_down_block_res_samples
|
855 |
+
|
856 |
+
# 4. mid
|
857 |
+
if self.mid_block is not None:
|
858 |
+
sample = self.mid_block(
|
859 |
+
sample,
|
860 |
+
emb,
|
861 |
+
encoder_hidden_states=encoder_hidden_states,
|
862 |
+
attention_mask=attention_mask,
|
863 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
864 |
+
encoder_attention_mask=encoder_attention_mask,
|
865 |
+
)
|
866 |
+
|
867 |
+
if mid_block_additional_residual is not None:
|
868 |
+
sample = sample + mid_block_additional_residual
|
869 |
+
|
870 |
+
# 5. up
|
871 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
872 |
+
is_final_block = i == len(self.up_blocks) - 1
|
873 |
+
|
874 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
875 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
876 |
+
|
877 |
+
# if we have not reached the final block and need to forward the
|
878 |
+
# upsample size, we do it here
|
879 |
+
if not is_final_block and forward_upsample_size:
|
880 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
881 |
+
|
882 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
883 |
+
sample = upsample_block(
|
884 |
+
hidden_states=sample,
|
885 |
+
temb=emb,
|
886 |
+
res_hidden_states_tuple=res_samples,
|
887 |
+
encoder_hidden_states=encoder_hidden_states,
|
888 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
889 |
+
upsample_size=upsample_size,
|
890 |
+
attention_mask=attention_mask,
|
891 |
+
encoder_attention_mask=encoder_attention_mask,
|
892 |
+
)
|
893 |
+
else:
|
894 |
+
sample = upsample_block(
|
895 |
+
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
896 |
+
)
|
897 |
+
|
898 |
+
# 6. post-process
|
899 |
+
if self.conv_norm_out:
|
900 |
+
sample = self.conv_norm_out(sample)
|
901 |
+
sample = self.conv_act(sample)
|
902 |
+
sample = self.conv_out(sample)
|
903 |
+
|
904 |
+
if not return_dict:
|
905 |
+
return (sample,)
|
906 |
+
|
907 |
+
return UNet2DConditionOutput(sample=sample)
|
diffusers/models/unet_2d_condition_guided.py
ADDED
@@ -0,0 +1,945 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from dataclasses import dataclass
|
16 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.utils.checkpoint
|
21 |
+
|
22 |
+
from ..utils.configuration_utils import ConfigMixin, register_to_config
|
23 |
+
from ..utils import logging
|
24 |
+
from .loaders import UNet2DConditionLoadersMixin
|
25 |
+
from .activations import get_activation
|
26 |
+
from .attention_processor import AttentionProcessor, AttnProcessor
|
27 |
+
from .embeddings import (
|
28 |
+
GaussianFourierProjection,
|
29 |
+
TextImageProjection,
|
30 |
+
TextImageTimeEmbedding,
|
31 |
+
TextTimeEmbedding,
|
32 |
+
TimestepEmbedding,
|
33 |
+
Timesteps,
|
34 |
+
)
|
35 |
+
from .modeling_utils import ModelMixin
|
36 |
+
from .unet_2d_blocks import (
|
37 |
+
CrossAttnDownBlock2D,
|
38 |
+
CrossAttnUpBlock2D,
|
39 |
+
DownBlock2D,
|
40 |
+
UNetMidBlock2DCrossAttn,
|
41 |
+
UNetMidBlock2DSimpleCrossAttn,
|
42 |
+
UpBlock2D,
|
43 |
+
get_down_block,
|
44 |
+
get_up_block,
|
45 |
+
)
|
46 |
+
from .unet_2d_condition import UNet2DConditionOutput
|
47 |
+
|
48 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
49 |
+
|
50 |
+
|
51 |
+
class UNet2DConditionGuidedModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
52 |
+
r"""
|
53 |
+
UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample,
|
54 |
+
conditional state, and a timestep and returns sample shaped output.
|
55 |
+
|
56 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic
|
57 |
+
methods the library implements for all the models (such as downloading or saving, etc.)
|
58 |
+
|
59 |
+
Parameters:
|
60 |
+
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
61 |
+
Height and width of input/output sample.
|
62 |
+
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
|
63 |
+
out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
|
64 |
+
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
65 |
+
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
|
66 |
+
Whether to flip the sin to cos in the time embedding.
|
67 |
+
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
68 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to
|
69 |
+
`("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
70 |
+
The tuple of downsample blocks to use.
|
71 |
+
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
72 |
+
The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`,
|
73 |
+
will skip the mid block layer if `None`.
|
74 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to
|
75 |
+
`("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
|
76 |
+
The tuple of upsample blocks to use.
|
77 |
+
only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
|
78 |
+
Whether to include self-attention in the basic transformer blocks, see
|
79 |
+
[`~models.attention.BasicTransformerBlock`].
|
80 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
81 |
+
The tuple of output channels for each block.
|
82 |
+
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
83 |
+
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
|
84 |
+
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
|
85 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
86 |
+
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
|
87 |
+
If `None`, it will skip the normalization and activation layers in post-processing
|
88 |
+
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
89 |
+
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
|
90 |
+
The dimension of the cross attention features.
|
91 |
+
encoder_hid_dim (`int`, *optional*, defaults to None):
|
92 |
+
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
93 |
+
dimension to `cross_attention_dim`.
|
94 |
+
encoder_hid_dim_type (`str`, *optional*, defaults to None):
|
95 |
+
If given, the `encoder_hidden_states` and potentially other embeddings will be down-projected to text
|
96 |
+
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
97 |
+
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
98 |
+
num_attention_heads (`int`, *optional*):
|
99 |
+
The number of attention heads. If not defined, defaults to `attention_head_dim`
|
100 |
+
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
101 |
+
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
|
102 |
+
class_embed_type (`str`, *optional*, defaults to None):
|
103 |
+
The type of class embedding to use which is ultimately summed with the time embeddings.
|
104 |
+
Choose from `None`, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
105 |
+
addition_embed_type (`str`, *optional*, defaults to None):
|
106 |
+
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
107 |
+
"text". "text" will use the `TextTimeEmbedding` layer.
|
108 |
+
num_class_embeds (`int`, *optional*, defaults to None):
|
109 |
+
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
110 |
+
class conditioning with `class_embed_type` equal to `None`.
|
111 |
+
time_embedding_type (`str`, *optional*, default to `positional`):
|
112 |
+
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
|
113 |
+
time_embedding_dim (`int`, *optional*, default to `None`):
|
114 |
+
An optional override for the dimension of the projected time embedding.
|
115 |
+
time_embedding_act_fn (`str`, *optional*, default to `None`):
|
116 |
+
Optional activation function to use on the time embeddings only one time before they as passed
|
117 |
+
to the rest of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`.
|
118 |
+
timestep_post_act (`str, *optional*, default to `None`):
|
119 |
+
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
|
120 |
+
time_cond_proj_dim (`int`, *optional*, default to `None`):
|
121 |
+
The dimension of `cond_proj` layer in timestep embedding.
|
122 |
+
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
|
123 |
+
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
|
124 |
+
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
|
125 |
+
using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
|
126 |
+
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
|
127 |
+
embeddings with the class embeddings.
|
128 |
+
mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
|
129 |
+
Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
|
130 |
+
`only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is None, the
|
131 |
+
`only_cross_attention` value will be used as the value for `mid_block_only_cross_attention`.
|
132 |
+
Else, it will default to `False`.
|
133 |
+
"""
|
134 |
+
|
135 |
+
_supports_gradient_checkpointing = True
|
136 |
+
|
137 |
+
@register_to_config
|
138 |
+
def __init__(
|
139 |
+
self,
|
140 |
+
sample_size: Optional[int] = None,
|
141 |
+
in_channels: int = 4,
|
142 |
+
out_channels: int = 4,
|
143 |
+
center_input_sample: bool = False,
|
144 |
+
flip_sin_to_cos: bool = True,
|
145 |
+
freq_shift: int = 0,
|
146 |
+
down_block_types: Tuple[str] = (
|
147 |
+
"CrossAttnDownBlock2D",
|
148 |
+
"CrossAttnDownBlock2D",
|
149 |
+
"CrossAttnDownBlock2D",
|
150 |
+
"DownBlock2D",
|
151 |
+
),
|
152 |
+
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
153 |
+
up_block_types: Tuple[str] = (
|
154 |
+
"UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"
|
155 |
+
),
|
156 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
157 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
158 |
+
layers_per_block: Union[int, Tuple[int]] = 2,
|
159 |
+
downsample_padding: int = 1,
|
160 |
+
mid_block_scale_factor: float = 1,
|
161 |
+
act_fn: str = "silu",
|
162 |
+
norm_num_groups: Optional[int] = 32,
|
163 |
+
norm_eps: float = 1e-5,
|
164 |
+
cross_attention_dim: Union[int, Tuple[int]] = 1280,
|
165 |
+
encoder_hid_dim: Optional[int] = None,
|
166 |
+
encoder_hid_dim_type: Optional[str] = None,
|
167 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
168 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
169 |
+
dual_cross_attention: bool = False,
|
170 |
+
use_linear_projection: bool = False,
|
171 |
+
class_embed_type: Optional[str] = None,
|
172 |
+
addition_embed_type: Optional[str] = None,
|
173 |
+
num_class_embeds: Optional[int] = None,
|
174 |
+
upcast_attention: bool = False,
|
175 |
+
resnet_time_scale_shift: str = "default",
|
176 |
+
resnet_skip_time_act: bool = False,
|
177 |
+
resnet_out_scale_factor: int = 1.0,
|
178 |
+
time_embedding_type: str = "positional",
|
179 |
+
time_embedding_dim: Optional[int] = None,
|
180 |
+
time_embedding_act_fn: Optional[str] = None,
|
181 |
+
timestep_post_act: Optional[str] = None,
|
182 |
+
time_cond_proj_dim: Optional[int] = None,
|
183 |
+
guidance_embedding_type: str = "fourier",
|
184 |
+
guidance_embedding_dim: Optional[int] = None,
|
185 |
+
guidance_post_act: Optional[str] = None,
|
186 |
+
guidance_cond_proj_dim: Optional[int] = None,
|
187 |
+
conv_in_kernel: int = 3,
|
188 |
+
conv_out_kernel: int = 3,
|
189 |
+
projection_class_embeddings_input_dim: Optional[int] = None,
|
190 |
+
class_embeddings_concat: bool = False,
|
191 |
+
mid_block_only_cross_attention: Optional[bool] = None,
|
192 |
+
cross_attention_norm: Optional[str] = None,
|
193 |
+
addition_embed_type_num_heads=64,
|
194 |
+
):
|
195 |
+
super().__init__()
|
196 |
+
|
197 |
+
self.sample_size = sample_size
|
198 |
+
|
199 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
200 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
201 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
202 |
+
# when this library was created. The incorrect naming was only discovered much later in
|
203 |
+
# https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
204 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too
|
205 |
+
# backwards breaking which is why we correct for the naming here.
|
206 |
+
num_attention_heads = num_attention_heads or attention_head_dim
|
207 |
+
|
208 |
+
# Check inputs
|
209 |
+
if len(down_block_types) != len(up_block_types):
|
210 |
+
raise ValueError(
|
211 |
+
"Must provide the same number of `down_block_types` as `up_block_types`. "
|
212 |
+
f"`down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
213 |
+
)
|
214 |
+
|
215 |
+
if len(block_out_channels) != len(down_block_types):
|
216 |
+
raise ValueError(
|
217 |
+
"Must provide the same number of `block_out_channels` as `down_block_types`. "
|
218 |
+
f"`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
219 |
+
)
|
220 |
+
|
221 |
+
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
222 |
+
raise ValueError(
|
223 |
+
"Must provide the same number of `only_cross_attention` as `down_block_types`. "
|
224 |
+
f"`only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
225 |
+
)
|
226 |
+
|
227 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
228 |
+
raise ValueError(
|
229 |
+
"Must provide the same number of `num_attention_heads` as `down_block_types`. "
|
230 |
+
f"`num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
231 |
+
)
|
232 |
+
|
233 |
+
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
|
234 |
+
raise ValueError(
|
235 |
+
"Must provide the same number of `attention_head_dim` as `down_block_types`. "
|
236 |
+
f"`attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
|
237 |
+
)
|
238 |
+
|
239 |
+
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
|
240 |
+
raise ValueError(
|
241 |
+
"Must provide the same number of `cross_attention_dim` as `down_block_types`. "
|
242 |
+
f"`cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
|
243 |
+
)
|
244 |
+
|
245 |
+
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
|
246 |
+
raise ValueError(
|
247 |
+
f"Must provide the same number of `layers_per_block` as `down_block_types`. "
|
248 |
+
f"`layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
|
249 |
+
)
|
250 |
+
|
251 |
+
# input
|
252 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
253 |
+
self.conv_in = nn.Conv2d(
|
254 |
+
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
255 |
+
)
|
256 |
+
|
257 |
+
# time and guidance embeddings
|
258 |
+
embedding_types = {'time': time_embedding_type, 'guidance': guidance_embedding_type}
|
259 |
+
embedding_dims = {'time': time_embedding_dim, 'guidance': guidance_embedding_dim}
|
260 |
+
embed_dims, embed_input_dims, embed_projs = {}, {}, {}
|
261 |
+
|
262 |
+
for key in ['time', 'guidance']:
|
263 |
+
logger.info(f"Using {embedding_types[key]} embedding for {key}.")
|
264 |
+
|
265 |
+
if embedding_types[key] == "fourier":
|
266 |
+
embed_dims[key] = embedding_dims[key] or block_out_channels[0] * 4
|
267 |
+
embed_input_dims[key] = embed_dims[key]
|
268 |
+
if embed_dims[key] % 2 != 0:
|
269 |
+
raise ValueError(
|
270 |
+
f"`{key}_embed_dim` should be divisible by 2, but is {embed_dims[key]}."
|
271 |
+
)
|
272 |
+
embed_projs[key] = GaussianFourierProjection(
|
273 |
+
embed_dims[key] // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
|
274 |
+
)
|
275 |
+
|
276 |
+
elif embedding_types[key] == "positional":
|
277 |
+
embed_dims[key] = embedding_dims[key] or block_out_channels[0] * 4
|
278 |
+
embed_input_dims[key] = block_out_channels[0]
|
279 |
+
embed_projs[key] = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
280 |
+
|
281 |
+
else:
|
282 |
+
raise ValueError(
|
283 |
+
f"{embedding_types[key]} does not exist for {key} embedding. "
|
284 |
+
f"Please make sure to use one of `fourier` or `positional`."
|
285 |
+
)
|
286 |
+
|
287 |
+
self.time_proj, self.guidance_proj = embed_projs['time'], embed_projs['guidance']
|
288 |
+
|
289 |
+
self.time_embedding = TimestepEmbedding(
|
290 |
+
embed_input_dims['time'],
|
291 |
+
embed_dims['time'],
|
292 |
+
act_fn=act_fn,
|
293 |
+
post_act_fn=timestep_post_act,
|
294 |
+
cond_proj_dim=time_cond_proj_dim,
|
295 |
+
)
|
296 |
+
self.guidance_embedding = TimestepEmbedding(
|
297 |
+
embed_input_dims['guidance'],
|
298 |
+
embed_dims['guidance'],
|
299 |
+
act_fn=act_fn,
|
300 |
+
post_act_fn=guidance_post_act,
|
301 |
+
cond_proj_dim=guidance_cond_proj_dim,
|
302 |
+
)
|
303 |
+
|
304 |
+
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
305 |
+
encoder_hid_dim_type = "text_proj"
|
306 |
+
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
307 |
+
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
308 |
+
|
309 |
+
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
310 |
+
raise ValueError(
|
311 |
+
"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` "
|
312 |
+
f"is set to {encoder_hid_dim_type}."
|
313 |
+
)
|
314 |
+
|
315 |
+
if encoder_hid_dim_type == "text_proj":
|
316 |
+
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
317 |
+
elif encoder_hid_dim_type == "text_image_proj":
|
318 |
+
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much,
|
319 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the
|
320 |
+
# currently only use case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
|
321 |
+
self.encoder_hid_proj = TextImageProjection(
|
322 |
+
text_embed_dim=encoder_hid_dim,
|
323 |
+
image_embed_dim=cross_attention_dim,
|
324 |
+
cross_attention_dim=cross_attention_dim,
|
325 |
+
)
|
326 |
+
|
327 |
+
elif encoder_hid_dim_type is not None:
|
328 |
+
raise ValueError(
|
329 |
+
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
330 |
+
)
|
331 |
+
else:
|
332 |
+
self.encoder_hid_proj = None
|
333 |
+
|
334 |
+
# class embedding
|
335 |
+
# print(f"class_embed_type: {class_embed_type}, num_class_embeds: {num_class_embeds}")
|
336 |
+
if class_embed_type is None and num_class_embeds is not None:
|
337 |
+
self.class_embedding = nn.Embedding(num_class_embeds, embedding_dims['time'])
|
338 |
+
elif class_embed_type == "timestep":
|
339 |
+
self.class_embedding = TimestepEmbedding(
|
340 |
+
embed_input_dims['time'], embed_dims['time'], act_fn=act_fn
|
341 |
+
)
|
342 |
+
elif class_embed_type == "identity":
|
343 |
+
self.class_embedding = nn.Identity(embed_dims['time'], embed_dims['time'])
|
344 |
+
elif class_embed_type == "projection":
|
345 |
+
if projection_class_embeddings_input_dim is None:
|
346 |
+
raise ValueError(
|
347 |
+
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
348 |
+
)
|
349 |
+
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
350 |
+
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
351 |
+
# 2. it projects from an arbitrary input dimension.
|
352 |
+
#
|
353 |
+
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
354 |
+
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal
|
355 |
+
# embeddings. As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
356 |
+
self.class_embedding = TimestepEmbedding(
|
357 |
+
projection_class_embeddings_input_dim, embed_dims['time']
|
358 |
+
)
|
359 |
+
elif class_embed_type == "simple_projection":
|
360 |
+
if projection_class_embeddings_input_dim is None:
|
361 |
+
raise ValueError(
|
362 |
+
"`class_embed_type`: 'simple_projection' requires "
|
363 |
+
"`projection_class_embeddings_input_dim` be set"
|
364 |
+
)
|
365 |
+
self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, embed_dims['time'])
|
366 |
+
else:
|
367 |
+
self.class_embedding = None
|
368 |
+
|
369 |
+
# Addition embedding
|
370 |
+
if addition_embed_type == "text":
|
371 |
+
if encoder_hid_dim is not None:
|
372 |
+
text_time_embedding_from_dim = encoder_hid_dim
|
373 |
+
else:
|
374 |
+
text_time_embedding_from_dim = cross_attention_dim
|
375 |
+
|
376 |
+
self.add_embedding = TextTimeEmbedding(
|
377 |
+
text_time_embedding_from_dim, embed_dims['time'], num_heads=addition_embed_type_num_heads
|
378 |
+
)
|
379 |
+
elif addition_embed_type == "text_image":
|
380 |
+
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`.
|
381 |
+
# To not clutter the __init__ too much, they are set to `cross_attention_dim`
|
382 |
+
# here as this is exactly the required dimension for the currently only use case
|
383 |
+
# when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
|
384 |
+
self.add_embedding = TextImageTimeEmbedding(
|
385 |
+
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim,
|
386 |
+
time_embed_dim=embed_dims['time']
|
387 |
+
)
|
388 |
+
elif addition_embed_type is not None:
|
389 |
+
raise ValueError(
|
390 |
+
f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'."
|
391 |
+
)
|
392 |
+
|
393 |
+
# Embedding activation function
|
394 |
+
if time_embedding_act_fn is None:
|
395 |
+
self.time_embed_act = None
|
396 |
+
else:
|
397 |
+
self.time_embed_act = get_activation(time_embedding_act_fn)
|
398 |
+
|
399 |
+
self.down_blocks = nn.ModuleList([])
|
400 |
+
self.up_blocks = nn.ModuleList([])
|
401 |
+
|
402 |
+
if isinstance(only_cross_attention, bool):
|
403 |
+
if mid_block_only_cross_attention is None:
|
404 |
+
mid_block_only_cross_attention = only_cross_attention
|
405 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
406 |
+
|
407 |
+
if mid_block_only_cross_attention is None:
|
408 |
+
mid_block_only_cross_attention = False
|
409 |
+
|
410 |
+
if isinstance(num_attention_heads, int):
|
411 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
412 |
+
|
413 |
+
if isinstance(attention_head_dim, int):
|
414 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
415 |
+
|
416 |
+
if isinstance(cross_attention_dim, int):
|
417 |
+
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
|
418 |
+
|
419 |
+
if isinstance(layers_per_block, int):
|
420 |
+
layers_per_block = [layers_per_block] * len(down_block_types)
|
421 |
+
|
422 |
+
if class_embeddings_concat:
|
423 |
+
# The time embeddings are concatenated with the class embeddings. The dimension of the
|
424 |
+
# time embeddings passed to the down, middle, and up blocks is twice the dimension of
|
425 |
+
# the regular time embeddings
|
426 |
+
# Now we have time emb, guidance emb, and class emb
|
427 |
+
blocks_time_embed_dim = embed_dims['time'] * 3
|
428 |
+
else:
|
429 |
+
blocks_time_embed_dim = embed_dims['time']
|
430 |
+
|
431 |
+
# down
|
432 |
+
output_channel = block_out_channels[0]
|
433 |
+
for i, down_block_type in enumerate(down_block_types):
|
434 |
+
input_channel = output_channel
|
435 |
+
output_channel = block_out_channels[i]
|
436 |
+
is_final_block = i == len(block_out_channels) - 1
|
437 |
+
|
438 |
+
down_block = get_down_block(
|
439 |
+
down_block_type,
|
440 |
+
num_layers=layers_per_block[i],
|
441 |
+
in_channels=input_channel,
|
442 |
+
out_channels=output_channel,
|
443 |
+
temb_channels=blocks_time_embed_dim,
|
444 |
+
add_downsample=not is_final_block,
|
445 |
+
resnet_eps=norm_eps,
|
446 |
+
resnet_act_fn=act_fn,
|
447 |
+
resnet_groups=norm_num_groups,
|
448 |
+
cross_attention_dim=cross_attention_dim[i],
|
449 |
+
num_attention_heads=num_attention_heads[i],
|
450 |
+
downsample_padding=downsample_padding,
|
451 |
+
dual_cross_attention=dual_cross_attention,
|
452 |
+
use_linear_projection=use_linear_projection,
|
453 |
+
only_cross_attention=only_cross_attention[i],
|
454 |
+
upcast_attention=upcast_attention,
|
455 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
456 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
457 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
458 |
+
cross_attention_norm=cross_attention_norm,
|
459 |
+
attention_head_dim=\
|
460 |
+
attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
461 |
+
)
|
462 |
+
self.down_blocks.append(down_block)
|
463 |
+
|
464 |
+
# mid
|
465 |
+
if mid_block_type == "UNetMidBlock2DCrossAttn":
|
466 |
+
self.mid_block = UNetMidBlock2DCrossAttn(
|
467 |
+
in_channels=block_out_channels[-1],
|
468 |
+
temb_channels=blocks_time_embed_dim,
|
469 |
+
resnet_eps=norm_eps,
|
470 |
+
resnet_act_fn=act_fn,
|
471 |
+
output_scale_factor=mid_block_scale_factor,
|
472 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
473 |
+
cross_attention_dim=cross_attention_dim[-1],
|
474 |
+
num_attention_heads=num_attention_heads[-1],
|
475 |
+
resnet_groups=norm_num_groups,
|
476 |
+
dual_cross_attention=dual_cross_attention,
|
477 |
+
use_linear_projection=use_linear_projection,
|
478 |
+
upcast_attention=upcast_attention,
|
479 |
+
)
|
480 |
+
elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
|
481 |
+
self.mid_block = UNetMidBlock2DSimpleCrossAttn(
|
482 |
+
in_channels=block_out_channels[-1],
|
483 |
+
temb_channels=blocks_time_embed_dim,
|
484 |
+
resnet_eps=norm_eps,
|
485 |
+
resnet_act_fn=act_fn,
|
486 |
+
output_scale_factor=mid_block_scale_factor,
|
487 |
+
cross_attention_dim=cross_attention_dim[-1],
|
488 |
+
attention_head_dim=attention_head_dim[-1],
|
489 |
+
resnet_groups=norm_num_groups,
|
490 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
491 |
+
skip_time_act=resnet_skip_time_act,
|
492 |
+
only_cross_attention=mid_block_only_cross_attention,
|
493 |
+
cross_attention_norm=cross_attention_norm,
|
494 |
+
)
|
495 |
+
elif mid_block_type is None:
|
496 |
+
self.mid_block = None
|
497 |
+
else:
|
498 |
+
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
499 |
+
|
500 |
+
# count how many layers upsample the images
|
501 |
+
self.num_upsamplers = 0
|
502 |
+
|
503 |
+
# up
|
504 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
505 |
+
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
506 |
+
reversed_layers_per_block = list(reversed(layers_per_block))
|
507 |
+
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
|
508 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
509 |
+
|
510 |
+
output_channel = reversed_block_out_channels[0]
|
511 |
+
for i, up_block_type in enumerate(up_block_types):
|
512 |
+
is_final_block = i == len(block_out_channels) - 1
|
513 |
+
|
514 |
+
prev_output_channel = output_channel
|
515 |
+
output_channel = reversed_block_out_channels[i]
|
516 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
517 |
+
|
518 |
+
# add upsample block for all BUT final layer
|
519 |
+
if not is_final_block:
|
520 |
+
add_upsample = True
|
521 |
+
self.num_upsamplers += 1
|
522 |
+
else:
|
523 |
+
add_upsample = False
|
524 |
+
|
525 |
+
up_block = get_up_block(
|
526 |
+
up_block_type,
|
527 |
+
num_layers=reversed_layers_per_block[i] + 1,
|
528 |
+
in_channels=input_channel,
|
529 |
+
out_channels=output_channel,
|
530 |
+
prev_output_channel=prev_output_channel,
|
531 |
+
temb_channels=blocks_time_embed_dim,
|
532 |
+
add_upsample=add_upsample,
|
533 |
+
resnet_eps=norm_eps,
|
534 |
+
resnet_act_fn=act_fn,
|
535 |
+
resnet_groups=norm_num_groups,
|
536 |
+
cross_attention_dim=reversed_cross_attention_dim[i],
|
537 |
+
num_attention_heads=reversed_num_attention_heads[i],
|
538 |
+
dual_cross_attention=dual_cross_attention,
|
539 |
+
use_linear_projection=use_linear_projection,
|
540 |
+
only_cross_attention=only_cross_attention[i],
|
541 |
+
upcast_attention=upcast_attention,
|
542 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
543 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
544 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
545 |
+
cross_attention_norm=cross_attention_norm,
|
546 |
+
attention_head_dim=\
|
547 |
+
attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
548 |
+
)
|
549 |
+
self.up_blocks.append(up_block)
|
550 |
+
prev_output_channel = output_channel
|
551 |
+
|
552 |
+
# out
|
553 |
+
if norm_num_groups is not None:
|
554 |
+
self.conv_norm_out = nn.GroupNorm(
|
555 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
556 |
+
)
|
557 |
+
self.conv_act = get_activation(act_fn)
|
558 |
+
|
559 |
+
else:
|
560 |
+
self.conv_norm_out = None
|
561 |
+
self.conv_act = None
|
562 |
+
|
563 |
+
conv_out_padding = (conv_out_kernel - 1) // 2
|
564 |
+
self.conv_out = nn.Conv2d(
|
565 |
+
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
|
566 |
+
)
|
567 |
+
|
568 |
+
@property
|
569 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
570 |
+
r"""
|
571 |
+
Returns:
|
572 |
+
`dict` of attention processors: A dictionary containing all attention processors used in
|
573 |
+
the model with indexed by its weight name.
|
574 |
+
"""
|
575 |
+
# set recursively
|
576 |
+
processors = {}
|
577 |
+
|
578 |
+
def fn_recursive_add_processors(
|
579 |
+
name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]
|
580 |
+
):
|
581 |
+
if hasattr(module, "set_processor"):
|
582 |
+
processors[f"{name}.processor"] = module.processor
|
583 |
+
|
584 |
+
for sub_name, child in module.named_children():
|
585 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
586 |
+
|
587 |
+
return processors
|
588 |
+
|
589 |
+
for name, module in self.named_children():
|
590 |
+
fn_recursive_add_processors(name, module, processors)
|
591 |
+
|
592 |
+
return processors
|
593 |
+
|
594 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
595 |
+
r"""
|
596 |
+
Parameters:
|
597 |
+
`processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
|
598 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the
|
599 |
+
processor of **all** `Attention` layers.
|
600 |
+
In case `processor` is a dict, the key needs to define the path to the corresponding cross
|
601 |
+
attention processor. This is strongly recommended when setting trainable attention processors.
|
602 |
+
"""
|
603 |
+
count = len(self.attn_processors.keys())
|
604 |
+
|
605 |
+
if isinstance(processor, dict) and len(processor) != count:
|
606 |
+
raise ValueError(
|
607 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match"
|
608 |
+
f" the number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
609 |
+
)
|
610 |
+
|
611 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
612 |
+
if hasattr(module, "set_processor"):
|
613 |
+
if not isinstance(processor, dict):
|
614 |
+
module.set_processor(processor)
|
615 |
+
else:
|
616 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
617 |
+
|
618 |
+
for sub_name, child in module.named_children():
|
619 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
620 |
+
|
621 |
+
for name, module in self.named_children():
|
622 |
+
fn_recursive_attn_processor(name, module, processor)
|
623 |
+
|
624 |
+
def set_default_attn_processor(self):
|
625 |
+
"""
|
626 |
+
Disables custom attention processors and sets the default attention implementation.
|
627 |
+
"""
|
628 |
+
self.set_attn_processor(AttnProcessor())
|
629 |
+
|
630 |
+
def set_attention_slice(self, slice_size):
|
631 |
+
r"""
|
632 |
+
Enable sliced attention computation.
|
633 |
+
|
634 |
+
When this option is enabled, the attention module will split the input tensor in slices, to compute
|
635 |
+
attention in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
636 |
+
|
637 |
+
Args:
|
638 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
639 |
+
When `"auto"`, halves the input to the attention heads, so attention will be computed in two
|
640 |
+
steps. "max"`, maximum amount of memory will be saved by running only one slice at a time.
|
641 |
+
If a number is provided, uses as many slices as `num_attention_heads // slice_size`.
|
642 |
+
In this case, `num_attention_heads` must be a multiple of `slice_size`.
|
643 |
+
"""
|
644 |
+
sliceable_head_dims = []
|
645 |
+
|
646 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
647 |
+
if hasattr(module, "set_attention_slice"):
|
648 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
649 |
+
|
650 |
+
for child in module.children():
|
651 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
652 |
+
|
653 |
+
# retrieve number of attention layers
|
654 |
+
for module in self.children():
|
655 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
656 |
+
|
657 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
658 |
+
|
659 |
+
if slice_size == "auto":
|
660 |
+
# half the attention head size is usually a good trade-off between
|
661 |
+
# speed and memory
|
662 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
663 |
+
elif slice_size == "max":
|
664 |
+
# make smallest slice possible
|
665 |
+
slice_size = num_sliceable_layers * [1]
|
666 |
+
|
667 |
+
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
668 |
+
|
669 |
+
if len(slice_size) != len(sliceable_head_dims):
|
670 |
+
raise ValueError(
|
671 |
+
f"You have provided {len(slice_size)}, but {self.config} has "
|
672 |
+
f"{len(sliceable_head_dims)} different attention layers. "
|
673 |
+
f"Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
674 |
+
)
|
675 |
+
|
676 |
+
for i in range(len(slice_size)):
|
677 |
+
size = slice_size[i]
|
678 |
+
dim = sliceable_head_dims[i]
|
679 |
+
if size is not None and size > dim:
|
680 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
681 |
+
|
682 |
+
# Recursively walk through all the children.
|
683 |
+
# Any children which exposes the set_attention_slice method
|
684 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
685 |
+
if hasattr(module, "set_attention_slice"):
|
686 |
+
module.set_attention_slice(slice_size.pop())
|
687 |
+
|
688 |
+
for child in module.children():
|
689 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
690 |
+
|
691 |
+
reversed_slice_size = list(reversed(slice_size))
|
692 |
+
for module in self.children():
|
693 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
694 |
+
|
695 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
696 |
+
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
|
697 |
+
module.gradient_checkpointing = value
|
698 |
+
|
699 |
+
def _prepare_tensor(self, value, device):
|
700 |
+
if not torch.is_tensor(value):
|
701 |
+
# Requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
702 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
703 |
+
if isinstance(value, float):
|
704 |
+
dtype = torch.float32 if device.type == "mps" else torch.float64
|
705 |
+
else:
|
706 |
+
dtype = torch.int32 if device.type == "mps" else torch.int64
|
707 |
+
|
708 |
+
return torch.tensor([value], dtype=dtype, device=device)
|
709 |
+
|
710 |
+
elif len(value.shape) == 0:
|
711 |
+
return value[None].to(device)
|
712 |
+
|
713 |
+
else:
|
714 |
+
return value
|
715 |
+
|
716 |
+
def forward(
|
717 |
+
self,
|
718 |
+
sample: torch.FloatTensor,
|
719 |
+
timestep: Union[torch.Tensor, float, int],
|
720 |
+
guidance: Union[torch.Tensor, float, int],
|
721 |
+
encoder_hidden_states: torch.Tensor,
|
722 |
+
class_labels: Optional[torch.Tensor] = None,
|
723 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
724 |
+
guidance_cond: Optional[torch.Tensor] = None,
|
725 |
+
attention_mask: Optional[torch.Tensor] = None,
|
726 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
727 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
728 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
729 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
730 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
731 |
+
return_dict: bool = True,
|
732 |
+
**kwargs
|
733 |
+
) -> Union[UNet2DConditionOutput, Tuple]:
|
734 |
+
r"""
|
735 |
+
Args:
|
736 |
+
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
737 |
+
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
738 |
+
encoder_hidden_states (`torch.FloatTensor`):
|
739 |
+
(batch, sequence_length, feature_dim) encoder hidden states
|
740 |
+
encoder_attention_mask (`torch.Tensor`):
|
741 |
+
(batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep,
|
742 |
+
False = discard. Mask will be converted into a bias, which adds large negative values to
|
743 |
+
attention scores corresponding to "discard" tokens.
|
744 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
745 |
+
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`]
|
746 |
+
instead of a plain tuple.
|
747 |
+
cross_attention_kwargs (`dict`, *optional*):
|
748 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined
|
749 |
+
under `self.processor` in [diffusers.cross_attention]
|
750 |
+
(https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
751 |
+
added_cond_kwargs (`dict`, *optional*):
|
752 |
+
A kwargs dictionary that if specified includes additonal conditions that can be used for
|
753 |
+
additonal time embeddings or encoder hidden states projections. See the configurations
|
754 |
+
`encoder_hid_dim_type` and `addition_embed_type` for more information.
|
755 |
+
|
756 |
+
Returns:
|
757 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
758 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
|
759 |
+
When returning a tuple, the first element is the sample tensor.
|
760 |
+
"""
|
761 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
762 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
763 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
764 |
+
# on the fly if necessary.
|
765 |
+
default_overall_up_factor = 2 ** self.num_upsamplers
|
766 |
+
|
767 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
768 |
+
forward_upsample_size = False
|
769 |
+
upsample_size = None
|
770 |
+
|
771 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
772 |
+
logger.info("Forward upsample size to force interpolation output size.")
|
773 |
+
forward_upsample_size = True
|
774 |
+
|
775 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
|
776 |
+
# expects mask of shape:
|
777 |
+
# [batch, key_tokens]
|
778 |
+
# adds singleton query_tokens dimension:
|
779 |
+
# [batch, 1, key_tokens]
|
780 |
+
# this helps to broadcast it as a bias over attention scores,
|
781 |
+
# which will be in one of the following shapes:
|
782 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
783 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
784 |
+
if attention_mask is not None:
|
785 |
+
# assume that mask is expressed as:
|
786 |
+
# (1 = keep, 0 = discard)
|
787 |
+
# convert mask into a bias that can be added to attention scores:
|
788 |
+
# (keep = +0, discard = -10000.0)
|
789 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
790 |
+
attention_mask = attention_mask.unsqueeze(1)
|
791 |
+
|
792 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
793 |
+
if encoder_attention_mask is not None:
|
794 |
+
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * (-10000.0)
|
795 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
796 |
+
|
797 |
+
# 0. center input if necessary
|
798 |
+
if self.config.center_input_sample:
|
799 |
+
sample = 2 * sample - 1.0
|
800 |
+
|
801 |
+
# 1. time and guidance
|
802 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
803 |
+
timestep = self._prepare_tensor(timestep, sample.device).expand(sample.shape[0])
|
804 |
+
# Project to get embedding
|
805 |
+
# `Timestep` does not contain any weights and will always return fp32 tensors
|
806 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
807 |
+
t_emb = self.time_proj(timestep).to(dtype=sample.dtype)
|
808 |
+
t_emb = self.time_embedding(t_emb, timestep_cond)
|
809 |
+
|
810 |
+
guidance = self._prepare_tensor(guidance, sample.device).expand(sample.shape[0])
|
811 |
+
g_emb = self.guidance_proj(guidance).to(dtype=sample.dtype)
|
812 |
+
g_emb = self.guidance_embedding(g_emb, guidance_cond)
|
813 |
+
|
814 |
+
# 1.5. prepare other embeddings
|
815 |
+
if self.class_embedding is None:
|
816 |
+
emb = t_emb + g_emb
|
817 |
+
else:
|
818 |
+
if class_labels is None:
|
819 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
820 |
+
if self.config.class_embed_type == "timestep":
|
821 |
+
class_labels = self.time_proj(class_labels).to(dtype=sample.dtype)
|
822 |
+
class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
|
823 |
+
|
824 |
+
if self.config.class_embeddings_concat:
|
825 |
+
emb = torch.cat([t_emb, g_emb, class_emb], dim=-1)
|
826 |
+
else:
|
827 |
+
emb = t_emb + g_emb + class_emb
|
828 |
+
|
829 |
+
if self.config.addition_embed_type == "text":
|
830 |
+
aug_emb = self.add_embedding(encoder_hidden_states)
|
831 |
+
emb = emb + aug_emb
|
832 |
+
elif self.config.addition_embed_type == "text_image":
|
833 |
+
# Kadinsky 2.1 - style
|
834 |
+
if "image_embeds" not in added_cond_kwargs:
|
835 |
+
raise ValueError(
|
836 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' "
|
837 |
+
"which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
838 |
+
)
|
839 |
+
|
840 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
841 |
+
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
|
842 |
+
|
843 |
+
aug_emb = self.add_embedding(text_embs, image_embs)
|
844 |
+
emb = emb + aug_emb
|
845 |
+
|
846 |
+
if self.time_embed_act is not None:
|
847 |
+
emb = self.time_embed_act(emb)
|
848 |
+
|
849 |
+
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
|
850 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
|
851 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
|
852 |
+
# Kadinsky 2.1 - style
|
853 |
+
if "image_embeds" not in added_cond_kwargs:
|
854 |
+
raise ValueError(
|
855 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' "
|
856 |
+
"which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
857 |
+
)
|
858 |
+
|
859 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
860 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
|
861 |
+
|
862 |
+
# 2. pre-process
|
863 |
+
sample = self.conv_in(sample)
|
864 |
+
|
865 |
+
# 3. down
|
866 |
+
down_block_res_samples = (sample,)
|
867 |
+
for downsample_block in self.down_blocks:
|
868 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
869 |
+
sample, res_samples = downsample_block(
|
870 |
+
hidden_states=sample,
|
871 |
+
temb=emb,
|
872 |
+
encoder_hidden_states=encoder_hidden_states,
|
873 |
+
attention_mask=attention_mask,
|
874 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
875 |
+
encoder_attention_mask=encoder_attention_mask,
|
876 |
+
)
|
877 |
+
else:
|
878 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
879 |
+
|
880 |
+
down_block_res_samples += res_samples
|
881 |
+
|
882 |
+
if down_block_additional_residuals is not None:
|
883 |
+
new_down_block_res_samples = ()
|
884 |
+
|
885 |
+
for down_block_res_sample, down_block_additional_residual in zip(
|
886 |
+
down_block_res_samples, down_block_additional_residuals
|
887 |
+
):
|
888 |
+
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
889 |
+
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
|
890 |
+
|
891 |
+
down_block_res_samples = new_down_block_res_samples
|
892 |
+
|
893 |
+
# 4. mid
|
894 |
+
if self.mid_block is not None:
|
895 |
+
sample = self.mid_block(
|
896 |
+
sample,
|
897 |
+
emb,
|
898 |
+
encoder_hidden_states=encoder_hidden_states,
|
899 |
+
attention_mask=attention_mask,
|
900 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
901 |
+
encoder_attention_mask=encoder_attention_mask,
|
902 |
+
)
|
903 |
+
|
904 |
+
if mid_block_additional_residual is not None:
|
905 |
+
sample = sample + mid_block_additional_residual
|
906 |
+
|
907 |
+
# 5. up
|
908 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
909 |
+
is_final_block = i == len(self.up_blocks) - 1
|
910 |
+
|
911 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
912 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
913 |
+
|
914 |
+
# if we have not reached the final block and need to forward the
|
915 |
+
# upsample size, we do it here
|
916 |
+
if not is_final_block and forward_upsample_size:
|
917 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
918 |
+
|
919 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
920 |
+
sample = upsample_block(
|
921 |
+
hidden_states=sample,
|
922 |
+
temb=emb,
|
923 |
+
res_hidden_states_tuple=res_samples,
|
924 |
+
encoder_hidden_states=encoder_hidden_states,
|
925 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
926 |
+
upsample_size=upsample_size,
|
927 |
+
attention_mask=attention_mask,
|
928 |
+
encoder_attention_mask=encoder_attention_mask,
|
929 |
+
)
|
930 |
+
else:
|
931 |
+
sample = upsample_block(
|
932 |
+
hidden_states=sample, temb=emb,
|
933 |
+
res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
934 |
+
)
|
935 |
+
|
936 |
+
# 6. post-process
|
937 |
+
if self.conv_norm_out:
|
938 |
+
sample = self.conv_norm_out(sample)
|
939 |
+
sample = self.conv_act(sample)
|
940 |
+
sample = self.conv_out(sample)
|
941 |
+
|
942 |
+
if not return_dict:
|
943 |
+
return (sample,)
|
944 |
+
|
945 |
+
return UNet2DConditionOutput(sample=sample)
|
diffusers/scheduling_heun_discrete.py
ADDED
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
### This file has been modified for the purposes of the ConsistencyTTA generation. ###
|
16 |
+
|
17 |
+
import math
|
18 |
+
from typing import List, Optional, Tuple, Union
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
import torch
|
22 |
+
|
23 |
+
from .utils.configuration_utils import ConfigMixin, register_to_config
|
24 |
+
from .utils.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
|
25 |
+
|
26 |
+
|
27 |
+
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
28 |
+
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
|
29 |
+
"""
|
30 |
+
Create a beta schedule that discretizes the given alpha_t_bar function, which defines
|
31 |
+
the cumulative product of (1-beta) over time from t = [0,1].
|
32 |
+
|
33 |
+
Contains a function alpha_bar that takes an argument t and transforms it to the
|
34 |
+
cumulative product of (1-beta) up to that part of the diffusion process.
|
35 |
+
|
36 |
+
|
37 |
+
Args:
|
38 |
+
num_diffusion_timesteps (`int`): the number of betas to produce.
|
39 |
+
max_beta (`float`):
|
40 |
+
the maximum beta to use; use values lower than 1 to prevent singularities.
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
44 |
+
"""
|
45 |
+
|
46 |
+
def alpha_bar(time_step):
|
47 |
+
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
|
48 |
+
|
49 |
+
betas = []
|
50 |
+
for i in range(num_diffusion_timesteps):
|
51 |
+
t1 = i / num_diffusion_timesteps
|
52 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
53 |
+
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
54 |
+
return torch.tensor(betas, dtype=torch.float32)
|
55 |
+
|
56 |
+
|
57 |
+
class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
58 |
+
"""
|
59 |
+
Implements Algorithm 2 (Heun steps) from Karras et al. (2022). for discrete beta schedules.
|
60 |
+
Based on the original k-diffusion implementation by Katherine Crowson:
|
61 |
+
https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/
|
62 |
+
k_diffusion/sampling.py#L90
|
63 |
+
|
64 |
+
[`~ConfigMixin`] takes care of storing all config attributes that are passed
|
65 |
+
in the scheduler's `__init__` function, such as `num_train_timesteps`.
|
66 |
+
They can be accessed via `scheduler.config.num_train_timesteps`.
|
67 |
+
[`SchedulerMixin`] provides general loading and saving functionality via the
|
68 |
+
[`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions.
|
69 |
+
|
70 |
+
Args:
|
71 |
+
num_train_timesteps (`int`):
|
72 |
+
number of diffusion steps used to train the model.
|
73 |
+
beta_start (`float`):
|
74 |
+
the starting `beta` value of inference.
|
75 |
+
beta_end (`float`):
|
76 |
+
the final `beta` value.
|
77 |
+
beta_schedule (`str`):
|
78 |
+
the beta schedule, a mapping from a beta range to a sequence of betas for stepping
|
79 |
+
the model. Choose from `linear` or `scaled_linear`.
|
80 |
+
trained_betas (`np.ndarray`, optional):
|
81 |
+
option to pass an array of betas directly to the constructor to bypass
|
82 |
+
`beta_start`, `beta_end` etc.
|
83 |
+
options to clip the variance used when adding noise to the denoised sample.
|
84 |
+
Choose from `fixed_small`, `fixed_small_log`, `fixed_large`,
|
85 |
+
`fixed_large_log`, `learned` or `learned_range`.
|
86 |
+
prediction_type (`str`, default `epsilon`, optional):
|
87 |
+
prediction type of the scheduler function, one of
|
88 |
+
`epsilon` (predicting the noise of the diffusion process),
|
89 |
+
`sample` (directly predicting the noisy sample`), or
|
90 |
+
`v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf)
|
91 |
+
"""
|
92 |
+
|
93 |
+
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
94 |
+
order = 2
|
95 |
+
|
96 |
+
@register_to_config
|
97 |
+
def __init__(
|
98 |
+
self,
|
99 |
+
num_train_timesteps: int = 1000,
|
100 |
+
beta_start: float = 0.00085, # sensible defaults
|
101 |
+
beta_end: float = 0.012,
|
102 |
+
beta_schedule: str = "linear",
|
103 |
+
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
104 |
+
prediction_type: str = "epsilon",
|
105 |
+
use_karras_sigmas: Optional[bool] = False,
|
106 |
+
):
|
107 |
+
if trained_betas is not None:
|
108 |
+
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
109 |
+
elif beta_schedule == "linear":
|
110 |
+
self.betas = torch.linspace(
|
111 |
+
beta_start, beta_end, num_train_timesteps, dtype=torch.float32
|
112 |
+
)
|
113 |
+
elif beta_schedule == "scaled_linear":
|
114 |
+
# this schedule is very specific to the latent diffusion model.
|
115 |
+
self.betas = (
|
116 |
+
torch.linspace(
|
117 |
+
beta_start ** 0.5, beta_end ** 0.5,
|
118 |
+
num_train_timesteps, dtype=torch.float32
|
119 |
+
) ** 2
|
120 |
+
)
|
121 |
+
elif beta_schedule == "squaredcos_cap_v2":
|
122 |
+
# Glide cosine schedule
|
123 |
+
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
124 |
+
else:
|
125 |
+
raise NotImplementedError(
|
126 |
+
f"{beta_schedule} does is not implemented for {self.__class__}"
|
127 |
+
)
|
128 |
+
|
129 |
+
self.alphas = 1.0 - self.betas
|
130 |
+
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
131 |
+
|
132 |
+
# set all values
|
133 |
+
self.use_karras_sigmas = use_karras_sigmas
|
134 |
+
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
|
135 |
+
|
136 |
+
def index_for_timestep(self, timestep):
|
137 |
+
"""Get the first / last index at which self.timesteps == timestep
|
138 |
+
"""
|
139 |
+
assert len(timestep.shape) < 2
|
140 |
+
avail_timesteps = self.timesteps.reshape(1, -1).to(timestep.device)
|
141 |
+
mask = (avail_timesteps == timestep.reshape(-1, 1))
|
142 |
+
assert (mask.sum(dim=1) != 0).all(), f"timestep: {timestep.tolist()}"
|
143 |
+
mask = mask.cpu() * torch.arange(mask.shape[1]).reshape(1, -1)
|
144 |
+
|
145 |
+
if self.state_in_first_order:
|
146 |
+
return mask.argmax(dim=1).numpy()
|
147 |
+
else:
|
148 |
+
return mask.argmax(dim=1).numpy() - 1
|
149 |
+
|
150 |
+
def scale_model_input(
|
151 |
+
self,
|
152 |
+
sample: torch.FloatTensor,
|
153 |
+
timestep: Union[float, torch.FloatTensor],
|
154 |
+
) -> torch.FloatTensor:
|
155 |
+
"""
|
156 |
+
Ensures interchangeability with schedulers that need to scale the
|
157 |
+
denoising model input depending on the current timestep.
|
158 |
+
Args:
|
159 |
+
sample (`torch.FloatTensor`): input sample
|
160 |
+
timestep (`int`, optional): current timestep
|
161 |
+
Returns:
|
162 |
+
`torch.FloatTensor`: scaled input sample
|
163 |
+
"""
|
164 |
+
if not torch.is_tensor(timestep):
|
165 |
+
timestep = torch.tensor(timestep)
|
166 |
+
timestep = timestep.to(sample.device).reshape(-1)
|
167 |
+
step_index = self.index_for_timestep(timestep)
|
168 |
+
|
169 |
+
sigma = self.sigmas[step_index].reshape(-1, 1, 1, 1).to(sample.device)
|
170 |
+
sample = sample / ((sigma ** 2 + 1) ** 0.5) # sample *= sqrt_alpha_prod
|
171 |
+
return sample
|
172 |
+
|
173 |
+
def set_timesteps(
|
174 |
+
self,
|
175 |
+
num_inference_steps: int,
|
176 |
+
device: Union[str, torch.device] = None,
|
177 |
+
num_train_timesteps: Optional[int] = None,
|
178 |
+
):
|
179 |
+
"""
|
180 |
+
Sets the timesteps used for the diffusion chain.
|
181 |
+
Supporting function to be run before inference.
|
182 |
+
|
183 |
+
Args:
|
184 |
+
num_inference_steps (`int`):
|
185 |
+
the number of diffusion steps used when generating samples
|
186 |
+
with a pre-trained model.
|
187 |
+
device (`str` or `torch.device`, optional):
|
188 |
+
the device to which the timesteps should be moved to.
|
189 |
+
If `None`, the timesteps are not moved.
|
190 |
+
"""
|
191 |
+
self.num_inference_steps = num_inference_steps
|
192 |
+
num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps
|
193 |
+
|
194 |
+
timesteps = np.linspace(
|
195 |
+
0, num_train_timesteps - 1, num_inference_steps, dtype=float
|
196 |
+
)[::-1].copy()
|
197 |
+
|
198 |
+
# sigma^2 = beta / alpha
|
199 |
+
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
200 |
+
log_sigmas = np.log(sigmas)
|
201 |
+
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
202 |
+
|
203 |
+
if self.use_karras_sigmas:
|
204 |
+
sigmas = self._convert_to_karras(
|
205 |
+
in_sigmas=sigmas, num_inference_steps=self.num_inference_steps
|
206 |
+
)
|
207 |
+
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
208 |
+
|
209 |
+
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
|
210 |
+
sigmas = torch.from_numpy(sigmas).to(device=device)
|
211 |
+
self.sigmas = torch.cat(
|
212 |
+
[sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]
|
213 |
+
)
|
214 |
+
|
215 |
+
# standard deviation of the initial noise distribution
|
216 |
+
self.init_noise_sigma = self.sigmas.max()
|
217 |
+
|
218 |
+
timesteps = torch.from_numpy(timesteps)
|
219 |
+
timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
|
220 |
+
if 'mps' in str(device):
|
221 |
+
timesteps = timesteps.float()
|
222 |
+
self.timesteps = timesteps.to(device)
|
223 |
+
|
224 |
+
# empty dt and derivative
|
225 |
+
self.prev_derivative = None
|
226 |
+
self.dt = None
|
227 |
+
|
228 |
+
def _sigma_to_t(self, sigma, log_sigmas):
|
229 |
+
# get log sigma
|
230 |
+
log_sigma = np.log(sigma)
|
231 |
+
|
232 |
+
# get distribution
|
233 |
+
dists = log_sigma - log_sigmas[:, np.newaxis]
|
234 |
+
|
235 |
+
# get sigmas range
|
236 |
+
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(
|
237 |
+
max=log_sigmas.shape[0] - 2
|
238 |
+
)
|
239 |
+
high_idx = low_idx + 1
|
240 |
+
|
241 |
+
low = log_sigmas[low_idx]
|
242 |
+
high = log_sigmas[high_idx]
|
243 |
+
|
244 |
+
# interpolate sigmas
|
245 |
+
w = (low - log_sigma) / (low - high)
|
246 |
+
w = np.clip(w, 0, 1)
|
247 |
+
|
248 |
+
# transform interpolation to time range
|
249 |
+
t = (1 - w) * low_idx + w * high_idx
|
250 |
+
t = t.reshape(sigma.shape)
|
251 |
+
return t
|
252 |
+
|
253 |
+
def _convert_to_karras(
|
254 |
+
self, in_sigmas: torch.FloatTensor, num_inference_steps
|
255 |
+
) -> torch.FloatTensor:
|
256 |
+
"""Constructs the noise schedule of Karras et al. (2022)."""
|
257 |
+
|
258 |
+
sigma_min: float = in_sigmas[-1].item()
|
259 |
+
sigma_max: float = in_sigmas[0].item()
|
260 |
+
|
261 |
+
rho = 7.0 # 7.0 is the value used in the paper
|
262 |
+
ramp = np.linspace(0, 1, num_inference_steps)
|
263 |
+
min_inv_rho = sigma_min ** (1 / rho)
|
264 |
+
max_inv_rho = sigma_max ** (1 / rho)
|
265 |
+
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
266 |
+
return sigmas
|
267 |
+
|
268 |
+
@property
|
269 |
+
def state_in_first_order(self):
|
270 |
+
return self.dt is None
|
271 |
+
|
272 |
+
def step(
|
273 |
+
self,
|
274 |
+
model_output: Union[torch.FloatTensor, np.ndarray],
|
275 |
+
timestep: Union[float, torch.FloatTensor],
|
276 |
+
sample: Union[torch.FloatTensor, np.ndarray],
|
277 |
+
return_dict: bool = True,
|
278 |
+
) -> Union[SchedulerOutput, Tuple]:
|
279 |
+
"""
|
280 |
+
Predict the sample at the previous timestep by reversing the SDE.
|
281 |
+
Core function to propagate the diffusion process from the learned
|
282 |
+
model outputs (most often the predicted noise).
|
283 |
+
Args:
|
284 |
+
model_output (`torch.FloatTensor` or `np.ndarray`):
|
285 |
+
direct output from learned diffusion model.
|
286 |
+
timestep (`int`):
|
287 |
+
current discrete timestep in the diffusion chain.
|
288 |
+
sample (`torch.FloatTensor` or `np.ndarray`):
|
289 |
+
current instance of sample being created by diffusion process.
|
290 |
+
return_dict (`bool`):
|
291 |
+
option for returning tuple rather than SchedulerOutput class
|
292 |
+
Returns:
|
293 |
+
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
|
294 |
+
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict`
|
295 |
+
is True, otherwise a `tuple`. When returning a tuple,
|
296 |
+
the first element is the sample tensor.
|
297 |
+
"""
|
298 |
+
if not torch.is_tensor(timestep):
|
299 |
+
timestep = torch.tensor(timestep)
|
300 |
+
timestep = timestep.reshape(-1).to(sample.device)
|
301 |
+
step_index = self.index_for_timestep(timestep)
|
302 |
+
|
303 |
+
if self.state_in_first_order:
|
304 |
+
sigma = self.sigmas[step_index]
|
305 |
+
sigma_next = self.sigmas[step_index + 1]
|
306 |
+
else:
|
307 |
+
# 2nd order / Heun's method
|
308 |
+
sigma = self.sigmas[step_index - 1]
|
309 |
+
sigma_next = self.sigmas[step_index]
|
310 |
+
|
311 |
+
sigma = sigma.reshape(-1, 1, 1, 1).to(sample.device)
|
312 |
+
sigma_next = sigma_next.reshape(-1, 1, 1, 1).to(sample.device)
|
313 |
+
sigma_input = sigma if self.state_in_first_order else sigma_next
|
314 |
+
|
315 |
+
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
316 |
+
if self.config.prediction_type == "epsilon":
|
317 |
+
pred_original_sample = sample - sigma_input * model_output
|
318 |
+
elif self.config.prediction_type == "v_prediction":
|
319 |
+
alpha_prod = 1 / (sigma_input ** 2 + 1)
|
320 |
+
pred_original_sample = (
|
321 |
+
sample * alpha_prod - model_output * (sigma_input * alpha_prod ** .5)
|
322 |
+
)
|
323 |
+
elif self.config.prediction_type == "sample":
|
324 |
+
raise NotImplementedError("prediction_type not implemented yet: sample")
|
325 |
+
else:
|
326 |
+
raise ValueError(
|
327 |
+
f"prediction_type given as {self.config.prediction_type} "
|
328 |
+
"must be one of `epsilon`, or `v_prediction`"
|
329 |
+
)
|
330 |
+
|
331 |
+
if self.state_in_first_order:
|
332 |
+
# 2. Convert to an ODE derivative for 1st order
|
333 |
+
derivative = (sample - pred_original_sample) / sigma
|
334 |
+
# 3. delta timestep
|
335 |
+
dt = sigma_next - sigma
|
336 |
+
|
337 |
+
# store for 2nd order step
|
338 |
+
self.prev_derivative = derivative
|
339 |
+
self.dt = dt
|
340 |
+
self.sample = sample
|
341 |
+
else:
|
342 |
+
# 2. 2nd order / Heun's method
|
343 |
+
derivative = (sample - pred_original_sample) / sigma_next
|
344 |
+
derivative = (self.prev_derivative + derivative) / 2
|
345 |
+
|
346 |
+
# 3. take prev timestep & sample
|
347 |
+
dt = self.dt
|
348 |
+
sample = self.sample
|
349 |
+
|
350 |
+
# free dt and derivative
|
351 |
+
# Note, this puts the scheduler in "first order mode"
|
352 |
+
self.prev_derivative = None
|
353 |
+
self.dt = None
|
354 |
+
self.sample = None
|
355 |
+
|
356 |
+
prev_sample = sample + derivative * dt
|
357 |
+
|
358 |
+
if not return_dict:
|
359 |
+
return (prev_sample,)
|
360 |
+
|
361 |
+
return SchedulerOutput(prev_sample=prev_sample)
|
362 |
+
|
363 |
+
def add_noise(
|
364 |
+
self,
|
365 |
+
original_samples: torch.FloatTensor,
|
366 |
+
noise: torch.FloatTensor,
|
367 |
+
timesteps: torch.FloatTensor,
|
368 |
+
) -> torch.FloatTensor:
|
369 |
+
|
370 |
+
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
371 |
+
self.sigmas = self.sigmas.to(
|
372 |
+
device=original_samples.device, dtype=original_samples.dtype
|
373 |
+
)
|
374 |
+
self.timesteps = self.timesteps.to(original_samples.device)
|
375 |
+
timesteps = timesteps.to(original_samples.device)
|
376 |
+
|
377 |
+
step_indices = self.index_for_timestep(timesteps)
|
378 |
+
|
379 |
+
sigma = self.sigmas[step_indices].flatten()
|
380 |
+
while len(sigma.shape) < len(original_samples.shape):
|
381 |
+
sigma = sigma.unsqueeze(-1)
|
382 |
+
|
383 |
+
noisy_samples = original_samples + noise * sigma
|
384 |
+
return noisy_samples
|
385 |
+
|
386 |
+
def __len__(self):
|
387 |
+
return self.config.num_train_timesteps
|
diffusers/utils/configuration_utils.py
ADDED
@@ -0,0 +1,647 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
""" ConfigMixin base class and utilities."""
|
17 |
+
|
18 |
+
import dataclasses
|
19 |
+
import functools
|
20 |
+
import importlib
|
21 |
+
import inspect
|
22 |
+
import json
|
23 |
+
import os
|
24 |
+
import re
|
25 |
+
from collections import OrderedDict
|
26 |
+
from pathlib import PosixPath
|
27 |
+
from typing import Any, Dict, Tuple, Union
|
28 |
+
|
29 |
+
import numpy as np
|
30 |
+
from huggingface_hub import hf_hub_download
|
31 |
+
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
|
32 |
+
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
|
33 |
+
from requests import HTTPError
|
34 |
+
|
35 |
+
from .import_utils import DummyObject
|
36 |
+
from .deprecation_utils import deprecate
|
37 |
+
from .hub_utils import extract_commit_hash, http_user_agent
|
38 |
+
from .logging import get_logger
|
39 |
+
from .constants import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT
|
40 |
+
|
41 |
+
|
42 |
+
logger = get_logger(__name__)
|
43 |
+
|
44 |
+
_re_configuration_file = re.compile(r"config\.(.*)\.json")
|
45 |
+
|
46 |
+
|
47 |
+
class FrozenDict(OrderedDict):
|
48 |
+
def __init__(self, *args, **kwargs):
|
49 |
+
super().__init__(*args, **kwargs)
|
50 |
+
|
51 |
+
for key, value in self.items():
|
52 |
+
setattr(self, key, value)
|
53 |
+
|
54 |
+
self.__frozen = True
|
55 |
+
|
56 |
+
def __delitem__(self, *args, **kwargs):
|
57 |
+
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
|
58 |
+
|
59 |
+
def setdefault(self, *args, **kwargs):
|
60 |
+
raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
|
61 |
+
|
62 |
+
def pop(self, *args, **kwargs):
|
63 |
+
raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
|
64 |
+
|
65 |
+
def update(self, *args, **kwargs):
|
66 |
+
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
|
67 |
+
|
68 |
+
def __setattr__(self, name, value):
|
69 |
+
if hasattr(self, "__frozen") and self.__frozen:
|
70 |
+
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
|
71 |
+
super().__setattr__(name, value)
|
72 |
+
|
73 |
+
def __setitem__(self, name, value):
|
74 |
+
if hasattr(self, "__frozen") and self.__frozen:
|
75 |
+
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
|
76 |
+
super().__setitem__(name, value)
|
77 |
+
|
78 |
+
|
79 |
+
class ConfigMixin:
|
80 |
+
r"""
|
81 |
+
Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also
|
82 |
+
provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and
|
83 |
+
saving classes that inherit from [`ConfigMixin`].
|
84 |
+
|
85 |
+
Class attributes:
|
86 |
+
- **config_name** (`str`) -- A filename under which the config should stored when calling
|
87 |
+
[`~ConfigMixin.save_config`] (should be overridden by parent class).
|
88 |
+
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
|
89 |
+
overridden by subclass).
|
90 |
+
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
|
91 |
+
- **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
|
92 |
+
should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
|
93 |
+
subclass).
|
94 |
+
"""
|
95 |
+
config_name = None
|
96 |
+
ignore_for_config = []
|
97 |
+
has_compatibles = False
|
98 |
+
|
99 |
+
_deprecated_kwargs = []
|
100 |
+
|
101 |
+
def register_to_config(self, **kwargs):
|
102 |
+
if self.config_name is None:
|
103 |
+
raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
|
104 |
+
# Special case for `kwargs` used in deprecation warning added to schedulers
|
105 |
+
# TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
|
106 |
+
# or solve in a more general way.
|
107 |
+
kwargs.pop("kwargs", None)
|
108 |
+
|
109 |
+
if not hasattr(self, "_internal_dict"):
|
110 |
+
internal_dict = kwargs
|
111 |
+
else:
|
112 |
+
previous_dict = dict(self._internal_dict)
|
113 |
+
internal_dict = {**self._internal_dict, **kwargs}
|
114 |
+
logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
|
115 |
+
|
116 |
+
self._internal_dict = FrozenDict(internal_dict)
|
117 |
+
|
118 |
+
def __getattr__(self, name: str) -> Any:
|
119 |
+
"""The only reason we overwrite `getattr` here is to gracefully deprecate accessing
|
120 |
+
config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
|
121 |
+
|
122 |
+
Tihs funtion is mostly copied from PyTorch's __getattr__ overwrite:
|
123 |
+
https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
|
124 |
+
"""
|
125 |
+
|
126 |
+
is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
|
127 |
+
is_attribute = name in self.__dict__
|
128 |
+
|
129 |
+
if is_in_config and not is_attribute:
|
130 |
+
deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
|
131 |
+
deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
|
132 |
+
return self._internal_dict[name]
|
133 |
+
|
134 |
+
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
|
135 |
+
|
136 |
+
def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
137 |
+
"""
|
138 |
+
Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
|
139 |
+
[`~ConfigMixin.from_config`] class method.
|
140 |
+
|
141 |
+
Args:
|
142 |
+
save_directory (`str` or `os.PathLike`):
|
143 |
+
Directory where the configuration JSON file is saved (will be created if it does not exist).
|
144 |
+
"""
|
145 |
+
if os.path.isfile(save_directory):
|
146 |
+
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
|
147 |
+
|
148 |
+
os.makedirs(save_directory, exist_ok=True)
|
149 |
+
|
150 |
+
# If we save using the predefined names, we can load using `from_config`
|
151 |
+
output_config_file = os.path.join(save_directory, self.config_name)
|
152 |
+
|
153 |
+
self.to_json_file(output_config_file)
|
154 |
+
logger.info(f"Configuration saved in {output_config_file}")
|
155 |
+
|
156 |
+
@classmethod
|
157 |
+
def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
|
158 |
+
r"""
|
159 |
+
Instantiate a Python class from a config dictionary.
|
160 |
+
|
161 |
+
Parameters:
|
162 |
+
config (`Dict[str, Any]`):
|
163 |
+
A config dictionary from which the Python class is instantiated. Make sure to only load configuration
|
164 |
+
files of compatible classes.
|
165 |
+
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
166 |
+
Whether kwargs that are not consumed by the Python class should be returned or not.
|
167 |
+
kwargs (remaining dictionary of keyword arguments, *optional*):
|
168 |
+
Can be used to update the configuration object (after it is loaded) and initiate the Python class.
|
169 |
+
`**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually
|
170 |
+
overwrite the same named arguments in `config`.
|
171 |
+
|
172 |
+
Returns:
|
173 |
+
[`ModelMixin`] or [`SchedulerMixin`]:
|
174 |
+
A model or scheduler object instantiated from a config dictionary.
|
175 |
+
|
176 |
+
Examples:
|
177 |
+
|
178 |
+
```python
|
179 |
+
>>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler
|
180 |
+
|
181 |
+
>>> # Download scheduler from huggingface.co and cache.
|
182 |
+
>>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
|
183 |
+
|
184 |
+
>>> # Instantiate DDIM scheduler class with same config as DDPM
|
185 |
+
>>> scheduler = DDIMScheduler.from_config(scheduler.config)
|
186 |
+
|
187 |
+
>>> # Instantiate PNDM scheduler class with same config as DDPM
|
188 |
+
>>> scheduler = PNDMScheduler.from_config(scheduler.config)
|
189 |
+
```
|
190 |
+
"""
|
191 |
+
# <===== TO BE REMOVED WITH DEPRECATION
|
192 |
+
# TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
|
193 |
+
if "pretrained_model_name_or_path" in kwargs:
|
194 |
+
config = kwargs.pop("pretrained_model_name_or_path")
|
195 |
+
|
196 |
+
if config is None:
|
197 |
+
raise ValueError("Please make sure to provide a config as the first positional argument.")
|
198 |
+
# ======>
|
199 |
+
|
200 |
+
if not isinstance(config, dict):
|
201 |
+
deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
|
202 |
+
if "Scheduler" in cls.__name__:
|
203 |
+
deprecation_message += (
|
204 |
+
f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
|
205 |
+
" Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
|
206 |
+
" be removed in v1.0.0."
|
207 |
+
)
|
208 |
+
elif "Model" in cls.__name__:
|
209 |
+
deprecation_message += (
|
210 |
+
f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
|
211 |
+
f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
|
212 |
+
" instead. This functionality will be removed in v1.0.0."
|
213 |
+
)
|
214 |
+
deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
|
215 |
+
config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)
|
216 |
+
|
217 |
+
init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)
|
218 |
+
|
219 |
+
# Allow dtype to be specified on initialization
|
220 |
+
if "dtype" in unused_kwargs:
|
221 |
+
init_dict["dtype"] = unused_kwargs.pop("dtype")
|
222 |
+
|
223 |
+
# add possible deprecated kwargs
|
224 |
+
for deprecated_kwarg in cls._deprecated_kwargs:
|
225 |
+
if deprecated_kwarg in unused_kwargs:
|
226 |
+
init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)
|
227 |
+
|
228 |
+
# Return model and optionally state and/or unused_kwargs
|
229 |
+
model = cls(**init_dict)
|
230 |
+
|
231 |
+
# make sure to also save config parameters that might be used for compatible classes
|
232 |
+
model.register_to_config(**hidden_dict)
|
233 |
+
|
234 |
+
# add hidden kwargs of compatible classes to unused_kwargs
|
235 |
+
unused_kwargs = {**unused_kwargs, **hidden_dict}
|
236 |
+
|
237 |
+
if return_unused_kwargs:
|
238 |
+
return (model, unused_kwargs)
|
239 |
+
else:
|
240 |
+
return model
|
241 |
+
|
242 |
+
@classmethod
|
243 |
+
def get_config_dict(cls, *args, **kwargs):
|
244 |
+
deprecation_message = (
|
245 |
+
f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
|
246 |
+
" removed in version v1.0.0"
|
247 |
+
)
|
248 |
+
deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
|
249 |
+
return cls.load_config(*args, **kwargs)
|
250 |
+
|
251 |
+
@classmethod
|
252 |
+
def load_config(
|
253 |
+
cls,
|
254 |
+
pretrained_model_name_or_path: Union[str, os.PathLike],
|
255 |
+
return_unused_kwargs=False,
|
256 |
+
return_commit_hash=False,
|
257 |
+
**kwargs,
|
258 |
+
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
259 |
+
r"""
|
260 |
+
Load a model or scheduler configuration.
|
261 |
+
|
262 |
+
Parameters:
|
263 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
264 |
+
Can be either:
|
265 |
+
|
266 |
+
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
267 |
+
the Hub.
|
268 |
+
- A path to a *directory* (for example `./my_model_directory`) containing model weights saved with
|
269 |
+
[`~ConfigMixin.save_config`].
|
270 |
+
|
271 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
272 |
+
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
273 |
+
is not used.
|
274 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
275 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
276 |
+
cached versions if they exist.
|
277 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
278 |
+
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
279 |
+
incompletely downloaded files are deleted.
|
280 |
+
proxies (`Dict[str, str]`, *optional*):
|
281 |
+
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
282 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
283 |
+
output_loading_info(`bool`, *optional*, defaults to `False`):
|
284 |
+
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
285 |
+
local_files_only (`bool`, *optional*, defaults to `False`):
|
286 |
+
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
287 |
+
won't be downloaded from the Hub.
|
288 |
+
use_auth_token (`str` or *bool*, *optional*):
|
289 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
290 |
+
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
291 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
292 |
+
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
293 |
+
allowed by Git.
|
294 |
+
subfolder (`str`, *optional*, defaults to `""`):
|
295 |
+
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
296 |
+
return_unused_kwargs (`bool`, *optional*, defaults to `False):
|
297 |
+
Whether unused keyword arguments of the config are returned.
|
298 |
+
return_commit_hash (`bool`, *optional*, defaults to `False):
|
299 |
+
Whether the `commit_hash` of the loaded configuration are returned.
|
300 |
+
|
301 |
+
Returns:
|
302 |
+
`dict`:
|
303 |
+
A dictionary of all the parameters stored in a JSON configuration file.
|
304 |
+
|
305 |
+
"""
|
306 |
+
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
307 |
+
force_download = kwargs.pop("force_download", False)
|
308 |
+
resume_download = kwargs.pop("resume_download", False)
|
309 |
+
proxies = kwargs.pop("proxies", None)
|
310 |
+
use_auth_token = kwargs.pop("use_auth_token", None)
|
311 |
+
local_files_only = kwargs.pop("local_files_only", False)
|
312 |
+
revision = kwargs.pop("revision", None)
|
313 |
+
_ = kwargs.pop("mirror", None)
|
314 |
+
subfolder = kwargs.pop("subfolder", None)
|
315 |
+
user_agent = kwargs.pop("user_agent", {})
|
316 |
+
|
317 |
+
user_agent = {**user_agent, "file_type": "config"}
|
318 |
+
user_agent = http_user_agent(user_agent)
|
319 |
+
|
320 |
+
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
321 |
+
|
322 |
+
if cls.config_name is None:
|
323 |
+
raise ValueError(
|
324 |
+
"`self.config_name` is not defined. Note that one should not load a config from "
|
325 |
+
"`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
|
326 |
+
)
|
327 |
+
|
328 |
+
if os.path.isfile(pretrained_model_name_or_path):
|
329 |
+
config_file = pretrained_model_name_or_path
|
330 |
+
elif os.path.isdir(pretrained_model_name_or_path):
|
331 |
+
if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
|
332 |
+
# Load from a PyTorch checkpoint
|
333 |
+
config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
|
334 |
+
elif subfolder is not None and os.path.isfile(
|
335 |
+
os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
|
336 |
+
):
|
337 |
+
config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
|
338 |
+
else:
|
339 |
+
raise EnvironmentError(
|
340 |
+
f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
|
341 |
+
)
|
342 |
+
else:
|
343 |
+
try:
|
344 |
+
# Load from URL or cache if already cached
|
345 |
+
config_file = hf_hub_download(
|
346 |
+
pretrained_model_name_or_path,
|
347 |
+
filename=cls.config_name,
|
348 |
+
cache_dir=cache_dir,
|
349 |
+
force_download=force_download,
|
350 |
+
proxies=proxies,
|
351 |
+
resume_download=resume_download,
|
352 |
+
local_files_only=local_files_only,
|
353 |
+
use_auth_token=use_auth_token,
|
354 |
+
user_agent=user_agent,
|
355 |
+
subfolder=subfolder,
|
356 |
+
revision=revision,
|
357 |
+
)
|
358 |
+
except RepositoryNotFoundError:
|
359 |
+
raise EnvironmentError(
|
360 |
+
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
|
361 |
+
" listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
|
362 |
+
" token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
|
363 |
+
" login`."
|
364 |
+
)
|
365 |
+
except RevisionNotFoundError:
|
366 |
+
raise EnvironmentError(
|
367 |
+
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
|
368 |
+
" this model name. Check the model page at"
|
369 |
+
f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
370 |
+
)
|
371 |
+
except EntryNotFoundError:
|
372 |
+
raise EnvironmentError(
|
373 |
+
f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
|
374 |
+
)
|
375 |
+
except HTTPError as err:
|
376 |
+
raise EnvironmentError(
|
377 |
+
"There was a specific connection error when trying to load"
|
378 |
+
f" {pretrained_model_name_or_path}:\n{err}"
|
379 |
+
)
|
380 |
+
except ValueError:
|
381 |
+
raise EnvironmentError(
|
382 |
+
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
|
383 |
+
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
384 |
+
f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
|
385 |
+
" run the library in offline mode at"
|
386 |
+
" 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
|
387 |
+
)
|
388 |
+
except EnvironmentError:
|
389 |
+
raise EnvironmentError(
|
390 |
+
f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
391 |
+
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
392 |
+
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
393 |
+
f"containing a {cls.config_name} file"
|
394 |
+
)
|
395 |
+
|
396 |
+
try:
|
397 |
+
# Load config dict
|
398 |
+
config_dict = cls._dict_from_json_file(config_file)
|
399 |
+
|
400 |
+
commit_hash = extract_commit_hash(config_file)
|
401 |
+
except (json.JSONDecodeError, UnicodeDecodeError):
|
402 |
+
raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
|
403 |
+
|
404 |
+
if not (return_unused_kwargs or return_commit_hash):
|
405 |
+
return config_dict
|
406 |
+
|
407 |
+
outputs = (config_dict,)
|
408 |
+
|
409 |
+
if return_unused_kwargs:
|
410 |
+
outputs += (kwargs,)
|
411 |
+
|
412 |
+
if return_commit_hash:
|
413 |
+
outputs += (commit_hash,)
|
414 |
+
|
415 |
+
return outputs
|
416 |
+
|
417 |
+
@staticmethod
|
418 |
+
def _get_init_keys(cls):
|
419 |
+
return set(dict(inspect.signature(cls.__init__).parameters).keys())
|
420 |
+
|
421 |
+
@classmethod
|
422 |
+
def extract_init_dict(cls, config_dict, **kwargs):
|
423 |
+
# 0. Copy origin config dict
|
424 |
+
original_dict = dict(config_dict.items())
|
425 |
+
|
426 |
+
# 1. Retrieve expected config attributes from __init__ signature
|
427 |
+
expected_keys = cls._get_init_keys(cls)
|
428 |
+
expected_keys.remove("self")
|
429 |
+
# remove general kwargs if present in dict
|
430 |
+
if "kwargs" in expected_keys:
|
431 |
+
expected_keys.remove("kwargs")
|
432 |
+
# remove flax internal keys
|
433 |
+
if hasattr(cls, "_flax_internal_args"):
|
434 |
+
for arg in cls._flax_internal_args:
|
435 |
+
expected_keys.remove(arg)
|
436 |
+
|
437 |
+
# 2. Remove attributes that cannot be expected from expected config attributes
|
438 |
+
# remove keys to be ignored
|
439 |
+
if len(cls.ignore_for_config) > 0:
|
440 |
+
expected_keys = expected_keys - set(cls.ignore_for_config)
|
441 |
+
|
442 |
+
# load diffusers library to import compatible and original scheduler
|
443 |
+
diffusers_library = importlib.import_module(__name__.split(".")[0])
|
444 |
+
|
445 |
+
if cls.has_compatibles:
|
446 |
+
compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
|
447 |
+
else:
|
448 |
+
compatible_classes = []
|
449 |
+
|
450 |
+
expected_keys_comp_cls = set()
|
451 |
+
for c in compatible_classes:
|
452 |
+
expected_keys_c = cls._get_init_keys(c)
|
453 |
+
expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
|
454 |
+
expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
|
455 |
+
config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
|
456 |
+
|
457 |
+
# remove attributes from orig class that cannot be expected
|
458 |
+
orig_cls_name = config_dict.pop("_class_name", cls.__name__)
|
459 |
+
if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
|
460 |
+
orig_cls = getattr(diffusers_library, orig_cls_name)
|
461 |
+
unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
|
462 |
+
config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
|
463 |
+
|
464 |
+
# remove private attributes
|
465 |
+
config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
|
466 |
+
|
467 |
+
# 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
|
468 |
+
init_dict = {}
|
469 |
+
for key in expected_keys:
|
470 |
+
# if config param is passed to kwarg and is present in config dict
|
471 |
+
# it should overwrite existing config dict key
|
472 |
+
if key in kwargs and key in config_dict:
|
473 |
+
config_dict[key] = kwargs.pop(key)
|
474 |
+
|
475 |
+
if key in kwargs:
|
476 |
+
# overwrite key
|
477 |
+
init_dict[key] = kwargs.pop(key)
|
478 |
+
elif key in config_dict:
|
479 |
+
# use value from config dict
|
480 |
+
init_dict[key] = config_dict.pop(key)
|
481 |
+
|
482 |
+
# 4. Give nice warning if unexpected values have been passed
|
483 |
+
if len(config_dict) > 0:
|
484 |
+
logger.warning(
|
485 |
+
f"The config attributes {config_dict} were passed to {cls.__name__}, "
|
486 |
+
"but are not expected and will be ignored. Please verify your "
|
487 |
+
f"{cls.config_name} configuration file."
|
488 |
+
)
|
489 |
+
|
490 |
+
# 5. Give nice info if config attributes are initiliazed to default because they have not been passed
|
491 |
+
passed_keys = set(init_dict.keys())
|
492 |
+
if len(expected_keys - passed_keys) > 0:
|
493 |
+
logger.info(
|
494 |
+
f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
|
495 |
+
)
|
496 |
+
|
497 |
+
# 6. Define unused keyword arguments
|
498 |
+
unused_kwargs = {**config_dict, **kwargs}
|
499 |
+
|
500 |
+
# 7. Define "hidden" config parameters that were saved for compatible classes
|
501 |
+
hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
|
502 |
+
|
503 |
+
return init_dict, unused_kwargs, hidden_config_dict
|
504 |
+
|
505 |
+
@classmethod
|
506 |
+
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
|
507 |
+
with open(json_file, "r", encoding="utf-8") as reader:
|
508 |
+
text = reader.read()
|
509 |
+
return json.loads(text)
|
510 |
+
|
511 |
+
def __repr__(self):
|
512 |
+
return f"{self.__class__.__name__} {self.to_json_string()}"
|
513 |
+
|
514 |
+
@property
|
515 |
+
def config(self) -> Dict[str, Any]:
|
516 |
+
"""
|
517 |
+
Returns the config of the class as a frozen dictionary
|
518 |
+
|
519 |
+
Returns:
|
520 |
+
`Dict[str, Any]`: Config of the class.
|
521 |
+
"""
|
522 |
+
return self._internal_dict
|
523 |
+
|
524 |
+
def to_json_string(self) -> str:
|
525 |
+
"""
|
526 |
+
Serializes the configuration instance to a JSON string.
|
527 |
+
|
528 |
+
Returns:
|
529 |
+
`str`:
|
530 |
+
String containing all the attributes that make up the configuration instance in JSON format.
|
531 |
+
"""
|
532 |
+
config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
|
533 |
+
config_dict["_class_name"] = self.__class__.__name__
|
534 |
+
config_dict["_diffusers_version"] = __version__
|
535 |
+
|
536 |
+
def to_json_saveable(value):
|
537 |
+
if isinstance(value, np.ndarray):
|
538 |
+
value = value.tolist()
|
539 |
+
elif isinstance(value, PosixPath):
|
540 |
+
value = str(value)
|
541 |
+
return value
|
542 |
+
|
543 |
+
config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
|
544 |
+
# Don't save "_ignore_files"
|
545 |
+
config_dict.pop("_ignore_files", None)
|
546 |
+
|
547 |
+
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
548 |
+
|
549 |
+
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
|
550 |
+
"""
|
551 |
+
Save the configuration instance's parameters to a JSON file.
|
552 |
+
|
553 |
+
Args:
|
554 |
+
json_file_path (`str` or `os.PathLike`):
|
555 |
+
Path to the JSON file to save a configuration instance's parameters.
|
556 |
+
"""
|
557 |
+
with open(json_file_path, "w", encoding="utf-8") as writer:
|
558 |
+
writer.write(self.to_json_string())
|
559 |
+
|
560 |
+
|
561 |
+
def register_to_config(init):
|
562 |
+
r"""
|
563 |
+
Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
|
564 |
+
automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
|
565 |
+
shouldn't be registered in the config, use the `ignore_for_config` class variable
|
566 |
+
|
567 |
+
Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
|
568 |
+
"""
|
569 |
+
|
570 |
+
@functools.wraps(init)
|
571 |
+
def inner_init(self, *args, **kwargs):
|
572 |
+
# Ignore private kwargs in the init.
|
573 |
+
init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
|
574 |
+
config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
|
575 |
+
if not isinstance(self, ConfigMixin):
|
576 |
+
raise RuntimeError(
|
577 |
+
f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
|
578 |
+
"not inherit from `ConfigMixin`."
|
579 |
+
)
|
580 |
+
|
581 |
+
ignore = getattr(self, "ignore_for_config", [])
|
582 |
+
# Get positional arguments aligned with kwargs
|
583 |
+
new_kwargs = {}
|
584 |
+
signature = inspect.signature(init)
|
585 |
+
parameters = {
|
586 |
+
name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
|
587 |
+
}
|
588 |
+
for arg, name in zip(args, parameters.keys()):
|
589 |
+
new_kwargs[name] = arg
|
590 |
+
|
591 |
+
# Then add all kwargs
|
592 |
+
new_kwargs.update(
|
593 |
+
{
|
594 |
+
k: init_kwargs.get(k, default)
|
595 |
+
for k, default in parameters.items()
|
596 |
+
if k not in ignore and k not in new_kwargs
|
597 |
+
}
|
598 |
+
)
|
599 |
+
new_kwargs = {**config_init_kwargs, **new_kwargs}
|
600 |
+
getattr(self, "register_to_config")(**new_kwargs)
|
601 |
+
init(self, *args, **init_kwargs)
|
602 |
+
|
603 |
+
return inner_init
|
604 |
+
|
605 |
+
|
606 |
+
def flax_register_to_config(cls):
|
607 |
+
original_init = cls.__init__
|
608 |
+
|
609 |
+
@functools.wraps(original_init)
|
610 |
+
def init(self, *args, **kwargs):
|
611 |
+
if not isinstance(self, ConfigMixin):
|
612 |
+
raise RuntimeError(
|
613 |
+
f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
|
614 |
+
"not inherit from `ConfigMixin`."
|
615 |
+
)
|
616 |
+
|
617 |
+
# Ignore private kwargs in the init. Retrieve all passed attributes
|
618 |
+
init_kwargs = dict(kwargs.items())
|
619 |
+
|
620 |
+
# Retrieve default values
|
621 |
+
fields = dataclasses.fields(self)
|
622 |
+
default_kwargs = {}
|
623 |
+
for field in fields:
|
624 |
+
# ignore flax specific attributes
|
625 |
+
if field.name in self._flax_internal_args:
|
626 |
+
continue
|
627 |
+
if type(field.default) == dataclasses._MISSING_TYPE:
|
628 |
+
default_kwargs[field.name] = None
|
629 |
+
else:
|
630 |
+
default_kwargs[field.name] = getattr(self, field.name)
|
631 |
+
|
632 |
+
# Make sure init_kwargs override default kwargs
|
633 |
+
new_kwargs = {**default_kwargs, **init_kwargs}
|
634 |
+
# dtype should be part of `init_kwargs`, but not `new_kwargs`
|
635 |
+
if "dtype" in new_kwargs:
|
636 |
+
new_kwargs.pop("dtype")
|
637 |
+
|
638 |
+
# Get positional arguments aligned with kwargs
|
639 |
+
for i, arg in enumerate(args):
|
640 |
+
name = fields[i].name
|
641 |
+
new_kwargs[name] = arg
|
642 |
+
|
643 |
+
getattr(self, "register_to_config")(**new_kwargs)
|
644 |
+
original_init(self, *args, **kwargs)
|
645 |
+
|
646 |
+
cls.__init__ = init
|
647 |
+
return cls
|
diffusers/utils/constants.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
|
17 |
+
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE, hf_cache_home
|
18 |
+
|
19 |
+
|
20 |
+
default_cache_path = HUGGINGFACE_HUB_CACHE
|
21 |
+
|
22 |
+
|
23 |
+
CONFIG_NAME = "config.json"
|
24 |
+
WEIGHTS_NAME = "diffusion_pytorch_model.bin"
|
25 |
+
FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
|
26 |
+
ONNX_WEIGHTS_NAME = "model.onnx"
|
27 |
+
SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
|
28 |
+
ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
|
29 |
+
HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
|
30 |
+
DIFFUSERS_CACHE = default_cache_path
|
31 |
+
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
|
32 |
+
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
|
33 |
+
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
|
34 |
+
TEXT_ENCODER_ATTN_MODULE = ".self_attn"
|
diffusers/utils/deprecation_utils.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
import warnings
|
3 |
+
from typing import Any, Dict, Optional, Union
|
4 |
+
|
5 |
+
from packaging import version
|
6 |
+
|
7 |
+
|
8 |
+
def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2):
|
9 |
+
from .. import __version__
|
10 |
+
|
11 |
+
deprecated_kwargs = take_from
|
12 |
+
values = ()
|
13 |
+
if not isinstance(args[0], tuple):
|
14 |
+
args = (args,)
|
15 |
+
|
16 |
+
for attribute, version_name, message in args:
|
17 |
+
if version.parse(version.parse(__version__).base_version) >= version.parse(version_name):
|
18 |
+
raise ValueError(
|
19 |
+
f"The deprecation tuple {(attribute, version_name, message)} should be removed since diffusers'"
|
20 |
+
f" version {__version__} is >= {version_name}"
|
21 |
+
)
|
22 |
+
|
23 |
+
warning = None
|
24 |
+
if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs:
|
25 |
+
values += (deprecated_kwargs.pop(attribute),)
|
26 |
+
warning = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}."
|
27 |
+
elif hasattr(deprecated_kwargs, attribute):
|
28 |
+
values += (getattr(deprecated_kwargs, attribute),)
|
29 |
+
warning = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}."
|
30 |
+
elif deprecated_kwargs is None:
|
31 |
+
warning = f"`{attribute}` is deprecated and will be removed in version {version_name}."
|
32 |
+
|
33 |
+
if warning is not None:
|
34 |
+
warning = warning + " " if standard_warn else ""
|
35 |
+
warnings.warn(warning + message, FutureWarning, stacklevel=stacklevel)
|
36 |
+
|
37 |
+
if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0:
|
38 |
+
call_frame = inspect.getouterframes(inspect.currentframe())[1]
|
39 |
+
filename = call_frame.filename
|
40 |
+
line_number = call_frame.lineno
|
41 |
+
function = call_frame.function
|
42 |
+
key, value = next(iter(deprecated_kwargs.items()))
|
43 |
+
raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`")
|
44 |
+
|
45 |
+
if len(values) == 0:
|
46 |
+
return
|
47 |
+
elif len(values) == 1:
|
48 |
+
return values[0]
|
49 |
+
return values
|
diffusers/utils/hub_utils.py
ADDED
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
|
17 |
+
import os
|
18 |
+
import re
|
19 |
+
import sys
|
20 |
+
import traceback
|
21 |
+
import warnings
|
22 |
+
from pathlib import Path
|
23 |
+
from typing import Dict, Optional, Union
|
24 |
+
from uuid import uuid4
|
25 |
+
|
26 |
+
from huggingface_hub import HfFolder, ModelCard, ModelCardData, hf_hub_download, whoami
|
27 |
+
from huggingface_hub.file_download import REGEX_COMMIT_HASH
|
28 |
+
from huggingface_hub.utils import (
|
29 |
+
EntryNotFoundError,
|
30 |
+
RepositoryNotFoundError,
|
31 |
+
RevisionNotFoundError,
|
32 |
+
is_jinja_available,
|
33 |
+
)
|
34 |
+
from packaging import version
|
35 |
+
from requests import HTTPError
|
36 |
+
|
37 |
+
from .constants import (
|
38 |
+
DEPRECATED_REVISION_ARGS,
|
39 |
+
DIFFUSERS_CACHE,
|
40 |
+
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
41 |
+
SAFETENSORS_WEIGHTS_NAME,
|
42 |
+
WEIGHTS_NAME,
|
43 |
+
)
|
44 |
+
from .import_utils import (
|
45 |
+
ENV_VARS_TRUE_VALUES,
|
46 |
+
_flax_version,
|
47 |
+
_jax_version,
|
48 |
+
_onnxruntime_version,
|
49 |
+
_torch_version,
|
50 |
+
is_flax_available,
|
51 |
+
is_onnx_available,
|
52 |
+
is_torch_available,
|
53 |
+
)
|
54 |
+
from .logging import get_logger
|
55 |
+
|
56 |
+
|
57 |
+
logger = get_logger(__name__)
|
58 |
+
|
59 |
+
|
60 |
+
MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "model_card_template.md"
|
61 |
+
SESSION_ID = uuid4().hex
|
62 |
+
HF_HUB_OFFLINE = os.getenv("HF_HUB_OFFLINE", "").upper() in ENV_VARS_TRUE_VALUES
|
63 |
+
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES
|
64 |
+
HUGGINGFACE_CO_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/"
|
65 |
+
|
66 |
+
|
67 |
+
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
|
68 |
+
"""
|
69 |
+
Formats a user-agent string with basic info about a request.
|
70 |
+
"""
|
71 |
+
ua = f"diffusers; python/{sys.version.split()[0]}; session_id/{SESSION_ID}"
|
72 |
+
if DISABLE_TELEMETRY or HF_HUB_OFFLINE:
|
73 |
+
return ua + "; telemetry/off"
|
74 |
+
if is_torch_available():
|
75 |
+
ua += f"; torch/{_torch_version}"
|
76 |
+
if is_flax_available():
|
77 |
+
ua += f"; jax/{_jax_version}"
|
78 |
+
ua += f"; flax/{_flax_version}"
|
79 |
+
if is_onnx_available():
|
80 |
+
ua += f"; onnxruntime/{_onnxruntime_version}"
|
81 |
+
# CI will set this value to True
|
82 |
+
if os.environ.get("DIFFUSERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
|
83 |
+
ua += "; is_ci/true"
|
84 |
+
if isinstance(user_agent, dict):
|
85 |
+
ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items())
|
86 |
+
elif isinstance(user_agent, str):
|
87 |
+
ua += "; " + user_agent
|
88 |
+
return ua
|
89 |
+
|
90 |
+
|
91 |
+
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
|
92 |
+
if token is None:
|
93 |
+
token = HfFolder.get_token()
|
94 |
+
if organization is None:
|
95 |
+
username = whoami(token)["name"]
|
96 |
+
return f"{username}/{model_id}"
|
97 |
+
else:
|
98 |
+
return f"{organization}/{model_id}"
|
99 |
+
|
100 |
+
|
101 |
+
def create_model_card(args, model_name):
|
102 |
+
if not is_jinja_available():
|
103 |
+
raise ValueError(
|
104 |
+
"Modelcard rendering is based on Jinja templates."
|
105 |
+
" Please make sure to have `jinja` installed before using `create_model_card`."
|
106 |
+
" To install it, please run `pip install Jinja2`."
|
107 |
+
)
|
108 |
+
|
109 |
+
if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
|
110 |
+
return
|
111 |
+
|
112 |
+
hub_token = args.hub_token if hasattr(args, "hub_token") else None
|
113 |
+
repo_name = get_full_repo_name(model_name, token=hub_token)
|
114 |
+
|
115 |
+
model_card = ModelCard.from_template(
|
116 |
+
card_data=ModelCardData( # Card metadata object that will be converted to YAML block
|
117 |
+
language="en",
|
118 |
+
license="apache-2.0",
|
119 |
+
library_name="diffusers",
|
120 |
+
tags=[],
|
121 |
+
datasets=args.dataset_name,
|
122 |
+
metrics=[],
|
123 |
+
),
|
124 |
+
template_path=MODEL_CARD_TEMPLATE_PATH,
|
125 |
+
model_name=model_name,
|
126 |
+
repo_name=repo_name,
|
127 |
+
dataset_name=args.dataset_name if hasattr(args, "dataset_name") else None,
|
128 |
+
learning_rate=args.learning_rate,
|
129 |
+
train_batch_size=args.train_batch_size,
|
130 |
+
eval_batch_size=args.eval_batch_size,
|
131 |
+
gradient_accumulation_steps=(
|
132 |
+
args.gradient_accumulation_steps if hasattr(args, "gradient_accumulation_steps") else None
|
133 |
+
),
|
134 |
+
adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None,
|
135 |
+
adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None,
|
136 |
+
adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None,
|
137 |
+
adam_epsilon=args.adam_epsilon if hasattr(args, "adam_epsilon") else None,
|
138 |
+
lr_scheduler=args.lr_scheduler if hasattr(args, "lr_scheduler") else None,
|
139 |
+
lr_warmup_steps=args.lr_warmup_steps if hasattr(args, "lr_warmup_steps") else None,
|
140 |
+
ema_inv_gamma=args.ema_inv_gamma if hasattr(args, "ema_inv_gamma") else None,
|
141 |
+
ema_power=args.ema_power if hasattr(args, "ema_power") else None,
|
142 |
+
ema_max_decay=args.ema_max_decay if hasattr(args, "ema_max_decay") else None,
|
143 |
+
mixed_precision=args.mixed_precision,
|
144 |
+
)
|
145 |
+
|
146 |
+
card_path = os.path.join(args.output_dir, "README.md")
|
147 |
+
model_card.save(card_path)
|
148 |
+
|
149 |
+
|
150 |
+
def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str] = None):
|
151 |
+
"""
|
152 |
+
Extracts the commit hash from a resolved filename toward a cache file.
|
153 |
+
"""
|
154 |
+
if resolved_file is None or commit_hash is not None:
|
155 |
+
return commit_hash
|
156 |
+
resolved_file = str(Path(resolved_file).as_posix())
|
157 |
+
search = re.search(r"snapshots/([^/]+)/", resolved_file)
|
158 |
+
if search is None:
|
159 |
+
return None
|
160 |
+
commit_hash = search.groups()[0]
|
161 |
+
return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None
|
162 |
+
|
163 |
+
|
164 |
+
# Old default cache path, potentially to be migrated.
|
165 |
+
# This logic was more or less taken from `transformers`, with the following differences:
|
166 |
+
# - Diffusers doesn't use custom environment variables to specify the cache path.
|
167 |
+
# - There is no need to migrate the cache format, just move the files to the new location.
|
168 |
+
hf_cache_home = os.path.expanduser(
|
169 |
+
os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
|
170 |
+
)
|
171 |
+
old_diffusers_cache = os.path.join(hf_cache_home, "diffusers")
|
172 |
+
|
173 |
+
|
174 |
+
def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str] = None) -> None:
|
175 |
+
if new_cache_dir is None:
|
176 |
+
new_cache_dir = DIFFUSERS_CACHE
|
177 |
+
if old_cache_dir is None:
|
178 |
+
old_cache_dir = old_diffusers_cache
|
179 |
+
|
180 |
+
old_cache_dir = Path(old_cache_dir).expanduser()
|
181 |
+
new_cache_dir = Path(new_cache_dir).expanduser()
|
182 |
+
for old_blob_path in old_cache_dir.glob("**/blobs/*"):
|
183 |
+
if old_blob_path.is_file() and not old_blob_path.is_symlink():
|
184 |
+
new_blob_path = new_cache_dir / old_blob_path.relative_to(old_cache_dir)
|
185 |
+
new_blob_path.parent.mkdir(parents=True, exist_ok=True)
|
186 |
+
os.replace(old_blob_path, new_blob_path)
|
187 |
+
try:
|
188 |
+
os.symlink(new_blob_path, old_blob_path)
|
189 |
+
except OSError:
|
190 |
+
logger.warning(
|
191 |
+
"Could not create symlink between old cache and new cache. If you use an older version of diffusers again, files will be re-downloaded."
|
192 |
+
)
|
193 |
+
# At this point, old_cache_dir contains symlinks to the new cache (it can still be used).
|
194 |
+
|
195 |
+
|
196 |
+
cache_version_file = os.path.join(DIFFUSERS_CACHE, "version_diffusers_cache.txt")
|
197 |
+
if not os.path.isfile(cache_version_file):
|
198 |
+
cache_version = 0
|
199 |
+
else:
|
200 |
+
with open(cache_version_file) as f:
|
201 |
+
cache_version = int(f.read())
|
202 |
+
|
203 |
+
if cache_version < 1:
|
204 |
+
old_cache_is_not_empty = os.path.isdir(old_diffusers_cache) and len(os.listdir(old_diffusers_cache)) > 0
|
205 |
+
if old_cache_is_not_empty:
|
206 |
+
logger.warning(
|
207 |
+
"The cache for model files in Diffusers v0.14.0 has moved to a new location. Moving your "
|
208 |
+
"existing cached models. This is a one-time operation, you can interrupt it or run it "
|
209 |
+
"later by calling `diffusers.utils.hub_utils.move_cache()`."
|
210 |
+
)
|
211 |
+
try:
|
212 |
+
move_cache()
|
213 |
+
except Exception as e:
|
214 |
+
trace = "\n".join(traceback.format_tb(e.__traceback__))
|
215 |
+
logger.error(
|
216 |
+
f"There was a problem when trying to move your cache:\n\n{trace}\n{e.__class__.__name__}: {e}\n\nPlease "
|
217 |
+
"file an issue at https://github.com/huggingface/diffusers/issues/new/choose, copy paste this whole "
|
218 |
+
"message and we will do our best to help."
|
219 |
+
)
|
220 |
+
|
221 |
+
if cache_version < 1:
|
222 |
+
try:
|
223 |
+
os.makedirs(DIFFUSERS_CACHE, exist_ok=True)
|
224 |
+
with open(cache_version_file, "w") as f:
|
225 |
+
f.write("1")
|
226 |
+
except Exception:
|
227 |
+
logger.warning(
|
228 |
+
f"There was a problem when trying to write in your cache folder ({DIFFUSERS_CACHE}). Please, ensure "
|
229 |
+
"the directory exists and can be written to."
|
230 |
+
)
|
231 |
+
|
232 |
+
|
233 |
+
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
|
234 |
+
if variant is not None:
|
235 |
+
splits = weights_name.split(".")
|
236 |
+
splits = splits[:-1] + [variant] + splits[-1:]
|
237 |
+
weights_name = ".".join(splits)
|
238 |
+
|
239 |
+
return weights_name
|
240 |
+
|
241 |
+
|
242 |
+
def _get_model_file(
|
243 |
+
pretrained_model_name_or_path,
|
244 |
+
*,
|
245 |
+
weights_name,
|
246 |
+
subfolder,
|
247 |
+
cache_dir,
|
248 |
+
force_download,
|
249 |
+
proxies,
|
250 |
+
resume_download,
|
251 |
+
local_files_only,
|
252 |
+
use_auth_token,
|
253 |
+
user_agent,
|
254 |
+
revision,
|
255 |
+
commit_hash=None,
|
256 |
+
):
|
257 |
+
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
258 |
+
if os.path.isfile(pretrained_model_name_or_path):
|
259 |
+
return pretrained_model_name_or_path
|
260 |
+
elif os.path.isdir(pretrained_model_name_or_path):
|
261 |
+
if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
|
262 |
+
# Load from a PyTorch checkpoint
|
263 |
+
model_file = os.path.join(pretrained_model_name_or_path, weights_name)
|
264 |
+
return model_file
|
265 |
+
elif subfolder is not None and os.path.isfile(
|
266 |
+
os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
|
267 |
+
):
|
268 |
+
model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
|
269 |
+
return model_file
|
270 |
+
else:
|
271 |
+
raise EnvironmentError(
|
272 |
+
f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
|
273 |
+
)
|
274 |
+
else:
|
275 |
+
# 1. First check if deprecated way of loading from branches is used
|
276 |
+
if (
|
277 |
+
revision in DEPRECATED_REVISION_ARGS
|
278 |
+
and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME)
|
279 |
+
and version.parse(version.parse(__version__).base_version) >= version.parse("0.17.0")
|
280 |
+
):
|
281 |
+
try:
|
282 |
+
model_file = hf_hub_download(
|
283 |
+
pretrained_model_name_or_path,
|
284 |
+
filename=_add_variant(weights_name, revision),
|
285 |
+
cache_dir=cache_dir,
|
286 |
+
force_download=force_download,
|
287 |
+
proxies=proxies,
|
288 |
+
resume_download=resume_download,
|
289 |
+
local_files_only=local_files_only,
|
290 |
+
use_auth_token=use_auth_token,
|
291 |
+
user_agent=user_agent,
|
292 |
+
subfolder=subfolder,
|
293 |
+
revision=revision or commit_hash,
|
294 |
+
)
|
295 |
+
warnings.warn(
|
296 |
+
f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
|
297 |
+
FutureWarning,
|
298 |
+
)
|
299 |
+
return model_file
|
300 |
+
except: # noqa: E722
|
301 |
+
warnings.warn(
|
302 |
+
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name, revision)} file in the 'main' branch of {pretrained_model_name_or_path}. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name, revision)}' so that the correct variant file can be added.",
|
303 |
+
FutureWarning,
|
304 |
+
)
|
305 |
+
try:
|
306 |
+
# 2. Load model file as usual
|
307 |
+
model_file = hf_hub_download(
|
308 |
+
pretrained_model_name_or_path,
|
309 |
+
filename=weights_name,
|
310 |
+
cache_dir=cache_dir,
|
311 |
+
force_download=force_download,
|
312 |
+
proxies=proxies,
|
313 |
+
resume_download=resume_download,
|
314 |
+
local_files_only=local_files_only,
|
315 |
+
use_auth_token=use_auth_token,
|
316 |
+
user_agent=user_agent,
|
317 |
+
subfolder=subfolder,
|
318 |
+
revision=revision or commit_hash,
|
319 |
+
)
|
320 |
+
return model_file
|
321 |
+
|
322 |
+
except RepositoryNotFoundError:
|
323 |
+
raise EnvironmentError(
|
324 |
+
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
325 |
+
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
|
326 |
+
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
|
327 |
+
"login`."
|
328 |
+
)
|
329 |
+
except RevisionNotFoundError:
|
330 |
+
raise EnvironmentError(
|
331 |
+
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
|
332 |
+
"this model name. Check the model page at "
|
333 |
+
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
334 |
+
)
|
335 |
+
except EntryNotFoundError:
|
336 |
+
raise EnvironmentError(
|
337 |
+
f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
|
338 |
+
)
|
339 |
+
except HTTPError as err:
|
340 |
+
raise EnvironmentError(
|
341 |
+
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
|
342 |
+
)
|
343 |
+
except ValueError:
|
344 |
+
raise EnvironmentError(
|
345 |
+
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
|
346 |
+
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
347 |
+
f" directory containing a file named {weights_name} or"
|
348 |
+
" \nCheckout your internet connection or see how to run the library in"
|
349 |
+
" offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
|
350 |
+
)
|
351 |
+
except EnvironmentError:
|
352 |
+
raise EnvironmentError(
|
353 |
+
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
354 |
+
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
355 |
+
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
356 |
+
f"containing a file named {weights_name}"
|
357 |
+
)
|
diffusers/utils/import_utils.py
ADDED
@@ -0,0 +1,649 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
Import utilities: Utilities related to imports and our lazy inits.
|
16 |
+
"""
|
17 |
+
|
18 |
+
import importlib.util
|
19 |
+
import operator as op
|
20 |
+
import os
|
21 |
+
import sys
|
22 |
+
from collections import OrderedDict
|
23 |
+
from typing import Union
|
24 |
+
|
25 |
+
from packaging import version
|
26 |
+
from packaging.version import Version, parse
|
27 |
+
|
28 |
+
from . import logging
|
29 |
+
|
30 |
+
|
31 |
+
# The package importlib_metadata is in a different place, depending on the python version.
|
32 |
+
if sys.version_info < (3, 8):
|
33 |
+
import importlib_metadata
|
34 |
+
else:
|
35 |
+
import importlib.metadata as importlib_metadata
|
36 |
+
|
37 |
+
|
38 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
39 |
+
|
40 |
+
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
|
41 |
+
ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
|
42 |
+
|
43 |
+
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
|
44 |
+
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
|
45 |
+
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
|
46 |
+
USE_SAFETENSORS = os.environ.get("USE_SAFETENSORS", "AUTO").upper()
|
47 |
+
|
48 |
+
STR_OPERATION_TO_FUNC = {
|
49 |
+
">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt
|
50 |
+
}
|
51 |
+
|
52 |
+
_torch_version = "N/A"
|
53 |
+
if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
|
54 |
+
_torch_available = importlib.util.find_spec("torch") is not None
|
55 |
+
if _torch_available:
|
56 |
+
try:
|
57 |
+
_torch_version = importlib_metadata.version("torch")
|
58 |
+
logger.info(f"PyTorch version {_torch_version} available.")
|
59 |
+
except importlib_metadata.PackageNotFoundError:
|
60 |
+
_torch_available = False
|
61 |
+
else:
|
62 |
+
logger.info("Disabling PyTorch because USE_TORCH is set")
|
63 |
+
_torch_available = False
|
64 |
+
|
65 |
+
|
66 |
+
_tf_version = "N/A"
|
67 |
+
if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
|
68 |
+
_tf_available = importlib.util.find_spec("tensorflow") is not None
|
69 |
+
if _tf_available:
|
70 |
+
candidates = (
|
71 |
+
"tensorflow",
|
72 |
+
"tensorflow-cpu",
|
73 |
+
"tensorflow-gpu",
|
74 |
+
"tf-nightly",
|
75 |
+
"tf-nightly-cpu",
|
76 |
+
"tf-nightly-gpu",
|
77 |
+
"intel-tensorflow",
|
78 |
+
"intel-tensorflow-avx512",
|
79 |
+
"tensorflow-rocm",
|
80 |
+
"tensorflow-macos",
|
81 |
+
"tensorflow-aarch64",
|
82 |
+
)
|
83 |
+
_tf_version = None
|
84 |
+
# For the metadata, we have to look for both tensorflow and tensorflow-cpu
|
85 |
+
for pkg in candidates:
|
86 |
+
try:
|
87 |
+
_tf_version = importlib_metadata.version(pkg)
|
88 |
+
break
|
89 |
+
except importlib_metadata.PackageNotFoundError:
|
90 |
+
pass
|
91 |
+
_tf_available = _tf_version is not None
|
92 |
+
if _tf_available:
|
93 |
+
if version.parse(_tf_version) < version.parse("2"):
|
94 |
+
logger.info(f"TensorFlow found but with version {_tf_version}. "
|
95 |
+
"Diffusers requires version 2 minimum.")
|
96 |
+
_tf_available = False
|
97 |
+
else:
|
98 |
+
logger.info(f"TensorFlow version {_tf_version} available.")
|
99 |
+
else:
|
100 |
+
logger.info("Disabling Tensorflow because USE_TORCH is set")
|
101 |
+
_tf_available = False
|
102 |
+
|
103 |
+
_jax_version = "N/A"
|
104 |
+
_flax_version = "N/A"
|
105 |
+
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
|
106 |
+
_flax_available = importlib.util.find_spec("jax") is not None \
|
107 |
+
and importlib.util.find_spec("flax") is not None
|
108 |
+
if _flax_available:
|
109 |
+
try:
|
110 |
+
_jax_version = importlib_metadata.version("jax")
|
111 |
+
_flax_version = importlib_metadata.version("flax")
|
112 |
+
logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
|
113 |
+
except importlib_metadata.PackageNotFoundError:
|
114 |
+
_flax_available = False
|
115 |
+
else:
|
116 |
+
_flax_available = False
|
117 |
+
|
118 |
+
if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES:
|
119 |
+
_safetensors_available = importlib.util.find_spec("safetensors") is not None
|
120 |
+
if _safetensors_available:
|
121 |
+
try:
|
122 |
+
_safetensors_version = importlib_metadata.version("safetensors")
|
123 |
+
logger.info(f"Safetensors version {_safetensors_version} available.")
|
124 |
+
except importlib_metadata.PackageNotFoundError:
|
125 |
+
_safetensors_available = False
|
126 |
+
else:
|
127 |
+
logger.info("Disabling Safetensors because USE_TF is set")
|
128 |
+
_safetensors_available = False
|
129 |
+
|
130 |
+
_transformers_available = importlib.util.find_spec("transformers") is not None
|
131 |
+
try:
|
132 |
+
_transformers_version = importlib_metadata.version("transformers")
|
133 |
+
logger.debug(f"Successfully imported transformers version {_transformers_version}")
|
134 |
+
except importlib_metadata.PackageNotFoundError:
|
135 |
+
_transformers_available = False
|
136 |
+
|
137 |
+
|
138 |
+
_inflect_available = importlib.util.find_spec("inflect") is not None
|
139 |
+
try:
|
140 |
+
_inflect_version = importlib_metadata.version("inflect")
|
141 |
+
logger.debug(f"Successfully imported inflect version {_inflect_version}")
|
142 |
+
except importlib_metadata.PackageNotFoundError:
|
143 |
+
_inflect_available = False
|
144 |
+
|
145 |
+
|
146 |
+
_unidecode_available = importlib.util.find_spec("unidecode") is not None
|
147 |
+
try:
|
148 |
+
_unidecode_version = importlib_metadata.version("unidecode")
|
149 |
+
logger.debug(f"Successfully imported unidecode version {_unidecode_version}")
|
150 |
+
except importlib_metadata.PackageNotFoundError:
|
151 |
+
_unidecode_available = False
|
152 |
+
|
153 |
+
|
154 |
+
_onnxruntime_version = "N/A"
|
155 |
+
_onnx_available = importlib.util.find_spec("onnxruntime") is not None
|
156 |
+
if _onnx_available:
|
157 |
+
candidates = (
|
158 |
+
"onnxruntime",
|
159 |
+
"onnxruntime-gpu",
|
160 |
+
"ort_nightly_gpu",
|
161 |
+
"onnxruntime-directml",
|
162 |
+
"onnxruntime-openvino",
|
163 |
+
"ort_nightly_directml",
|
164 |
+
"onnxruntime-rocm",
|
165 |
+
"onnxruntime-training",
|
166 |
+
)
|
167 |
+
_onnxruntime_version = None
|
168 |
+
# For the metadata, we have to look for both onnxruntime and onnxruntime-gpu
|
169 |
+
for pkg in candidates:
|
170 |
+
try:
|
171 |
+
_onnxruntime_version = importlib_metadata.version(pkg)
|
172 |
+
break
|
173 |
+
except importlib_metadata.PackageNotFoundError:
|
174 |
+
pass
|
175 |
+
_onnx_available = _onnxruntime_version is not None
|
176 |
+
if _onnx_available:
|
177 |
+
logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}")
|
178 |
+
|
179 |
+
# (sayakpaul): importlib.util.find_spec("opencv-python") returns None even when it's installed.
|
180 |
+
# _opencv_available = importlib.util.find_spec("opencv-python") is not None
|
181 |
+
try:
|
182 |
+
candidates = (
|
183 |
+
"opencv-python",
|
184 |
+
"opencv-contrib-python",
|
185 |
+
"opencv-python-headless",
|
186 |
+
"opencv-contrib-python-headless",
|
187 |
+
)
|
188 |
+
_opencv_version = None
|
189 |
+
for pkg in candidates:
|
190 |
+
try:
|
191 |
+
_opencv_version = importlib_metadata.version(pkg)
|
192 |
+
break
|
193 |
+
except importlib_metadata.PackageNotFoundError:
|
194 |
+
pass
|
195 |
+
_opencv_available = _opencv_version is not None
|
196 |
+
if _opencv_available:
|
197 |
+
logger.debug(f"Successfully imported cv2 version {_opencv_version}")
|
198 |
+
except importlib_metadata.PackageNotFoundError:
|
199 |
+
_opencv_available = False
|
200 |
+
|
201 |
+
_scipy_available = importlib.util.find_spec("scipy") is not None
|
202 |
+
try:
|
203 |
+
_scipy_version = importlib_metadata.version("scipy")
|
204 |
+
logger.debug(f"Successfully imported scipy version {_scipy_version}")
|
205 |
+
except importlib_metadata.PackageNotFoundError:
|
206 |
+
_scipy_available = False
|
207 |
+
|
208 |
+
_librosa_available = importlib.util.find_spec("librosa") is not None
|
209 |
+
try:
|
210 |
+
_librosa_version = importlib_metadata.version("librosa")
|
211 |
+
logger.debug(f"Successfully imported librosa version {_librosa_version}")
|
212 |
+
except importlib_metadata.PackageNotFoundError:
|
213 |
+
_librosa_available = False
|
214 |
+
|
215 |
+
_accelerate_available = importlib.util.find_spec("accelerate") is not None
|
216 |
+
try:
|
217 |
+
_accelerate_version = importlib_metadata.version("accelerate")
|
218 |
+
logger.debug(f"Successfully imported accelerate version {_accelerate_version}")
|
219 |
+
except importlib_metadata.PackageNotFoundError:
|
220 |
+
_accelerate_available = False
|
221 |
+
|
222 |
+
_xformers_available = importlib.util.find_spec("xformers") is not None
|
223 |
+
try:
|
224 |
+
_xformers_version = importlib_metadata.version("xformers")
|
225 |
+
if _torch_available:
|
226 |
+
import torch
|
227 |
+
|
228 |
+
if version.Version(torch.__version__) < version.Version("1.12"):
|
229 |
+
raise ValueError("PyTorch should be >= 1.12")
|
230 |
+
logger.debug(f"Successfully imported xformers version {_xformers_version}")
|
231 |
+
except importlib_metadata.PackageNotFoundError:
|
232 |
+
_xformers_available = False
|
233 |
+
|
234 |
+
_k_diffusion_available = importlib.util.find_spec("k_diffusion") is not None
|
235 |
+
try:
|
236 |
+
_k_diffusion_version = importlib_metadata.version("k_diffusion")
|
237 |
+
logger.debug(f"Successfully imported k-diffusion version {_k_diffusion_version}")
|
238 |
+
except importlib_metadata.PackageNotFoundError:
|
239 |
+
_k_diffusion_available = False
|
240 |
+
|
241 |
+
_note_seq_available = importlib.util.find_spec("note_seq") is not None
|
242 |
+
try:
|
243 |
+
_note_seq_version = importlib_metadata.version("note_seq")
|
244 |
+
logger.debug(f"Successfully imported note-seq version {_note_seq_version}")
|
245 |
+
except importlib_metadata.PackageNotFoundError:
|
246 |
+
_note_seq_available = False
|
247 |
+
|
248 |
+
_wandb_available = importlib.util.find_spec("wandb") is not None
|
249 |
+
try:
|
250 |
+
_wandb_version = importlib_metadata.version("wandb")
|
251 |
+
logger.debug(f"Successfully imported wandb version {_wandb_version }")
|
252 |
+
except importlib_metadata.PackageNotFoundError:
|
253 |
+
_wandb_available = False
|
254 |
+
|
255 |
+
_omegaconf_available = importlib.util.find_spec("omegaconf") is not None
|
256 |
+
try:
|
257 |
+
_omegaconf_version = importlib_metadata.version("omegaconf")
|
258 |
+
logger.debug(f"Successfully imported omegaconf version {_omegaconf_version}")
|
259 |
+
except importlib_metadata.PackageNotFoundError:
|
260 |
+
_omegaconf_available = False
|
261 |
+
|
262 |
+
_tensorboard_available = importlib.util.find_spec("tensorboard")
|
263 |
+
try:
|
264 |
+
_tensorboard_version = importlib_metadata.version("tensorboard")
|
265 |
+
logger.debug(f"Successfully imported tensorboard version {_tensorboard_version}")
|
266 |
+
except importlib_metadata.PackageNotFoundError:
|
267 |
+
_tensorboard_available = False
|
268 |
+
|
269 |
+
|
270 |
+
_compel_available = importlib.util.find_spec("compel")
|
271 |
+
try:
|
272 |
+
_compel_version = importlib_metadata.version("compel")
|
273 |
+
logger.debug(f"Successfully imported compel version {_compel_version}")
|
274 |
+
except importlib_metadata.PackageNotFoundError:
|
275 |
+
_compel_available = False
|
276 |
+
|
277 |
+
|
278 |
+
_ftfy_available = importlib.util.find_spec("ftfy") is not None
|
279 |
+
try:
|
280 |
+
_ftfy_version = importlib_metadata.version("ftfy")
|
281 |
+
logger.debug(f"Successfully imported ftfy version {_ftfy_version}")
|
282 |
+
except importlib_metadata.PackageNotFoundError:
|
283 |
+
_ftfy_available = False
|
284 |
+
|
285 |
+
|
286 |
+
_bs4_available = importlib.util.find_spec("bs4") is not None
|
287 |
+
try:
|
288 |
+
# importlib metadata under different name
|
289 |
+
_bs4_version = importlib_metadata.version("beautifulsoup4")
|
290 |
+
logger.debug(f"Successfully imported ftfy version {_bs4_version}")
|
291 |
+
except importlib_metadata.PackageNotFoundError:
|
292 |
+
_bs4_available = False
|
293 |
+
|
294 |
+
_torchsde_available = importlib.util.find_spec("torchsde") is not None
|
295 |
+
try:
|
296 |
+
_torchsde_version = importlib_metadata.version("torchsde")
|
297 |
+
logger.debug(f"Successfully imported torchsde version {_torchsde_version}")
|
298 |
+
except importlib_metadata.PackageNotFoundError:
|
299 |
+
_torchsde_available = False
|
300 |
+
|
301 |
+
|
302 |
+
def is_torch_available():
|
303 |
+
return _torch_available
|
304 |
+
|
305 |
+
|
306 |
+
def is_safetensors_available():
|
307 |
+
return _safetensors_available
|
308 |
+
|
309 |
+
|
310 |
+
def is_tf_available():
|
311 |
+
return _tf_available
|
312 |
+
|
313 |
+
|
314 |
+
def is_flax_available():
|
315 |
+
return _flax_available
|
316 |
+
|
317 |
+
|
318 |
+
def is_transformers_available():
|
319 |
+
return _transformers_available
|
320 |
+
|
321 |
+
|
322 |
+
def is_inflect_available():
|
323 |
+
return _inflect_available
|
324 |
+
|
325 |
+
|
326 |
+
def is_unidecode_available():
|
327 |
+
return _unidecode_available
|
328 |
+
|
329 |
+
|
330 |
+
def is_onnx_available():
|
331 |
+
return _onnx_available
|
332 |
+
|
333 |
+
|
334 |
+
def is_opencv_available():
|
335 |
+
return _opencv_available
|
336 |
+
|
337 |
+
|
338 |
+
def is_scipy_available():
|
339 |
+
return _scipy_available
|
340 |
+
|
341 |
+
|
342 |
+
def is_librosa_available():
|
343 |
+
return _librosa_available
|
344 |
+
|
345 |
+
|
346 |
+
def is_xformers_available():
|
347 |
+
return _xformers_available
|
348 |
+
|
349 |
+
|
350 |
+
def is_accelerate_available():
|
351 |
+
return _accelerate_available
|
352 |
+
|
353 |
+
|
354 |
+
def is_k_diffusion_available():
|
355 |
+
return _k_diffusion_available
|
356 |
+
|
357 |
+
|
358 |
+
def is_note_seq_available():
|
359 |
+
return _note_seq_available
|
360 |
+
|
361 |
+
|
362 |
+
def is_wandb_available():
|
363 |
+
return _wandb_available
|
364 |
+
|
365 |
+
|
366 |
+
def is_omegaconf_available():
|
367 |
+
return _omegaconf_available
|
368 |
+
|
369 |
+
|
370 |
+
def is_tensorboard_available():
|
371 |
+
return _tensorboard_available
|
372 |
+
|
373 |
+
|
374 |
+
def is_compel_available():
|
375 |
+
return _compel_available
|
376 |
+
|
377 |
+
|
378 |
+
def is_ftfy_available():
|
379 |
+
return _ftfy_available
|
380 |
+
|
381 |
+
|
382 |
+
def is_bs4_available():
|
383 |
+
return _bs4_available
|
384 |
+
|
385 |
+
|
386 |
+
def is_torchsde_available():
|
387 |
+
return _torchsde_available
|
388 |
+
|
389 |
+
|
390 |
+
# docstyle-ignore
|
391 |
+
FLAX_IMPORT_ERROR = """
|
392 |
+
{0} requires the FLAX library but it was not found in your environment.
|
393 |
+
Checkout the instructions on the installation page: https://github.com/google/flax
|
394 |
+
and follow the ones that match your environment.
|
395 |
+
"""
|
396 |
+
|
397 |
+
# docstyle-ignore
|
398 |
+
INFLECT_IMPORT_ERROR = """
|
399 |
+
{0} requires the inflect library but it was not found in your environment.
|
400 |
+
You can install it with pip: `pip install inflect`
|
401 |
+
"""
|
402 |
+
|
403 |
+
# docstyle-ignore
|
404 |
+
PYTORCH_IMPORT_ERROR = """
|
405 |
+
{0} requires the PyTorch library but it was not found in your environment.
|
406 |
+
Checkout the instructions on the installation page: https://pytorch.org/get-started/locally/
|
407 |
+
and follow the ones that match your environment.
|
408 |
+
"""
|
409 |
+
|
410 |
+
# docstyle-ignore
|
411 |
+
ONNX_IMPORT_ERROR = """
|
412 |
+
{0} requires the onnxruntime library but it was not found in your environment.
|
413 |
+
You can install it with pip: `pip install onnxruntime`
|
414 |
+
"""
|
415 |
+
|
416 |
+
# docstyle-ignore
|
417 |
+
OPENCV_IMPORT_ERROR = """
|
418 |
+
{0} requires the OpenCV library but it was not found in your environment.
|
419 |
+
You can install it with pip: `pip install opencv-python`
|
420 |
+
"""
|
421 |
+
|
422 |
+
# docstyle-ignore
|
423 |
+
SCIPY_IMPORT_ERROR = """
|
424 |
+
{0} requires the scipy library but it was not found in your environment.
|
425 |
+
You can install it with pip: `pip install scipy`
|
426 |
+
"""
|
427 |
+
|
428 |
+
# docstyle-ignore
|
429 |
+
LIBROSA_IMPORT_ERROR = """
|
430 |
+
{0} requires the librosa library but it was not found in your environment.
|
431 |
+
Checkout the instructions on the installation page: https://librosa.org/doc/latest/install.html
|
432 |
+
and follow the ones that match your environment.
|
433 |
+
"""
|
434 |
+
|
435 |
+
# docstyle-ignore
|
436 |
+
TRANSFORMERS_IMPORT_ERROR = """
|
437 |
+
{0} requires the transformers library but it was not found in your environment.
|
438 |
+
You can install it with pip: `pip install transformers`
|
439 |
+
"""
|
440 |
+
|
441 |
+
# docstyle-ignore
|
442 |
+
UNIDECODE_IMPORT_ERROR = """
|
443 |
+
{0} requires the unidecode library but it was not found in your environment.
|
444 |
+
You can install it with pip: `pip install Unidecode`
|
445 |
+
"""
|
446 |
+
|
447 |
+
# docstyle-ignore
|
448 |
+
K_DIFFUSION_IMPORT_ERROR = """
|
449 |
+
{0} requires the k-diffusion library but it was not found in your environment.
|
450 |
+
You can install it with pip: `pip install k-diffusion`
|
451 |
+
"""
|
452 |
+
|
453 |
+
# docstyle-ignore
|
454 |
+
NOTE_SEQ_IMPORT_ERROR = """
|
455 |
+
{0} requires the note-seq library but it was not found in your environment.
|
456 |
+
You can install it with pip: `pip install note-seq`
|
457 |
+
"""
|
458 |
+
|
459 |
+
# docstyle-ignore
|
460 |
+
WANDB_IMPORT_ERROR = """
|
461 |
+
{0} requires the wandb library but it was not found in your environment.
|
462 |
+
You can install it with pip: `pip install wandb`
|
463 |
+
"""
|
464 |
+
|
465 |
+
# docstyle-ignore
|
466 |
+
OMEGACONF_IMPORT_ERROR = """
|
467 |
+
{0} requires the omegaconf library but it was not found in your environment.
|
468 |
+
You can install it with pip: `pip install omegaconf`
|
469 |
+
"""
|
470 |
+
|
471 |
+
# docstyle-ignore
|
472 |
+
TENSORBOARD_IMPORT_ERROR = """
|
473 |
+
{0} requires the tensorboard library but it was not found in your environment.
|
474 |
+
You can install it with pip: `pip install tensorboard`
|
475 |
+
"""
|
476 |
+
|
477 |
+
|
478 |
+
# docstyle-ignore
|
479 |
+
COMPEL_IMPORT_ERROR = """
|
480 |
+
{0} requires the compel library but it was not found in your environment.
|
481 |
+
You can install it with pip: `pip install compel`
|
482 |
+
"""
|
483 |
+
|
484 |
+
# docstyle-ignore
|
485 |
+
BS4_IMPORT_ERROR = """
|
486 |
+
{0} requires the Beautiful Soup library but it was not found in your environment. You can install it with pip:
|
487 |
+
`pip install beautifulsoup4`. Please note that you may need to restart your runtime after installation.
|
488 |
+
"""
|
489 |
+
|
490 |
+
# docstyle-ignore
|
491 |
+
FTFY_IMPORT_ERROR = """
|
492 |
+
{0} requires the ftfy library but it was not found in your environment. Checkout the instructions on the
|
493 |
+
installation section: https://github.com/rspeer/python-ftfy/tree/master#installing and follow the ones
|
494 |
+
that match your environment. Please note that you may need to restart your runtime after installation.
|
495 |
+
"""
|
496 |
+
|
497 |
+
# docstyle-ignore
|
498 |
+
TORCHSDE_IMPORT_ERROR = """
|
499 |
+
{0} requires the torchsde library but it was not found in your environment.
|
500 |
+
You can install it with pip: `pip install torchsde`
|
501 |
+
"""
|
502 |
+
|
503 |
+
|
504 |
+
BACKENDS_MAPPING = OrderedDict(
|
505 |
+
[
|
506 |
+
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
|
507 |
+
("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
|
508 |
+
("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)),
|
509 |
+
("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)),
|
510 |
+
("opencv", (is_opencv_available, OPENCV_IMPORT_ERROR)),
|
511 |
+
("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
|
512 |
+
("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
|
513 |
+
("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
|
514 |
+
("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
|
515 |
+
("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)),
|
516 |
+
("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)),
|
517 |
+
("note_seq", (is_note_seq_available, NOTE_SEQ_IMPORT_ERROR)),
|
518 |
+
("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)),
|
519 |
+
("omegaconf", (is_omegaconf_available, OMEGACONF_IMPORT_ERROR)),
|
520 |
+
("tensorboard", (_tensorboard_available, TENSORBOARD_IMPORT_ERROR)),
|
521 |
+
("compel", (_compel_available, COMPEL_IMPORT_ERROR)),
|
522 |
+
("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)),
|
523 |
+
("torchsde", (_torchsde_available, TORCHSDE_IMPORT_ERROR)),
|
524 |
+
]
|
525 |
+
)
|
526 |
+
|
527 |
+
|
528 |
+
def requires_backends(obj, backends):
|
529 |
+
if not isinstance(backends, (list, tuple)):
|
530 |
+
backends = [backends]
|
531 |
+
|
532 |
+
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
|
533 |
+
checks = (BACKENDS_MAPPING[backend] for backend in backends)
|
534 |
+
failed = [msg.format(name) for available, msg in checks if not available()]
|
535 |
+
if failed:
|
536 |
+
raise ImportError("".join(failed))
|
537 |
+
|
538 |
+
if name in [
|
539 |
+
"VersatileDiffusionTextToImagePipeline",
|
540 |
+
"VersatileDiffusionPipeline",
|
541 |
+
"VersatileDiffusionDualGuidedPipeline",
|
542 |
+
"StableDiffusionImageVariationPipeline",
|
543 |
+
"UnCLIPPipeline",
|
544 |
+
] and is_transformers_version("<", "4.25.0"):
|
545 |
+
raise ImportError(
|
546 |
+
f"You need to install `transformers>=4.25` in order to use {name}: \n```\n pip install"
|
547 |
+
" --upgrade transformers \n```"
|
548 |
+
)
|
549 |
+
|
550 |
+
if name in ["StableDiffusionDepth2ImgPipeline", "StableDiffusionPix2PixZeroPipeline"] and is_transformers_version(
|
551 |
+
"<", "4.26.0"
|
552 |
+
):
|
553 |
+
raise ImportError(
|
554 |
+
f"You need to install `transformers>=4.26` in order to use {name}: \n```\n pip install"
|
555 |
+
" --upgrade transformers \n```"
|
556 |
+
)
|
557 |
+
|
558 |
+
|
559 |
+
class DummyObject(type):
|
560 |
+
"""
|
561 |
+
Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
|
562 |
+
`requires_backend` each time a user tries to access any method of that class.
|
563 |
+
"""
|
564 |
+
|
565 |
+
def __getattr__(cls, key):
|
566 |
+
if key.startswith("_"):
|
567 |
+
return super().__getattr__(cls, key)
|
568 |
+
requires_backends(cls, cls._backends)
|
569 |
+
|
570 |
+
|
571 |
+
# This function was copied from:
|
572 |
+
# https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
|
573 |
+
def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
|
574 |
+
"""
|
575 |
+
Args:
|
576 |
+
Compares a library version to some requirement using a given operation.
|
577 |
+
library_or_version (`str` or `packaging.version.Version`):
|
578 |
+
A library name or a version to check.
|
579 |
+
operation (`str`):
|
580 |
+
A string representation of an operator, such as `">"` or `"<="`.
|
581 |
+
requirement_version (`str`):
|
582 |
+
The version to compare the library version against
|
583 |
+
"""
|
584 |
+
if operation not in STR_OPERATION_TO_FUNC.keys():
|
585 |
+
raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}")
|
586 |
+
operation = STR_OPERATION_TO_FUNC[operation]
|
587 |
+
if isinstance(library_or_version, str):
|
588 |
+
library_or_version = parse(importlib_metadata.version(library_or_version))
|
589 |
+
return operation(library_or_version, parse(requirement_version))
|
590 |
+
|
591 |
+
|
592 |
+
# This function was copied from:
|
593 |
+
# https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
|
594 |
+
def is_torch_version(operation: str, version: str):
|
595 |
+
"""
|
596 |
+
Args:
|
597 |
+
Compares the current PyTorch version to a given reference with an operation.
|
598 |
+
operation (`str`):
|
599 |
+
A string representation of an operator, such as `">"` or `"<="`
|
600 |
+
version (`str`):
|
601 |
+
A string version of PyTorch
|
602 |
+
"""
|
603 |
+
return compare_versions(parse(_torch_version), operation, version)
|
604 |
+
|
605 |
+
|
606 |
+
def is_transformers_version(operation: str, version: str):
|
607 |
+
"""
|
608 |
+
Args:
|
609 |
+
Compares the current Transformers version to a given reference with an operation.
|
610 |
+
operation (`str`):
|
611 |
+
A string representation of an operator, such as `">"` or `"<="`
|
612 |
+
version (`str`):
|
613 |
+
A version string
|
614 |
+
"""
|
615 |
+
if not _transformers_available:
|
616 |
+
return False
|
617 |
+
return compare_versions(parse(_transformers_version), operation, version)
|
618 |
+
|
619 |
+
|
620 |
+
def is_accelerate_version(operation: str, version: str):
|
621 |
+
"""
|
622 |
+
Args:
|
623 |
+
Compares the current Accelerate version to a given reference with an operation.
|
624 |
+
operation (`str`):
|
625 |
+
A string representation of an operator, such as `">"` or `"<="`
|
626 |
+
version (`str`):
|
627 |
+
A version string
|
628 |
+
"""
|
629 |
+
if not _accelerate_available:
|
630 |
+
return False
|
631 |
+
return compare_versions(parse(_accelerate_version), operation, version)
|
632 |
+
|
633 |
+
|
634 |
+
def is_k_diffusion_version(operation: str, version: str):
|
635 |
+
"""
|
636 |
+
Args:
|
637 |
+
Compares the current k-diffusion version to a given reference with an operation.
|
638 |
+
operation (`str`):
|
639 |
+
A string representation of an operator, such as `">"` or `"<="`
|
640 |
+
version (`str`):
|
641 |
+
A version string
|
642 |
+
"""
|
643 |
+
if not _k_diffusion_available:
|
644 |
+
return False
|
645 |
+
return compare_versions(parse(_k_diffusion_version), operation, version)
|
646 |
+
|
647 |
+
|
648 |
+
class OptionalDependencyNotAvailable(BaseException):
|
649 |
+
"""An error indicating that an optional dependency of Diffusers was not found in the environment."""
|
diffusers/utils/logging.py
ADDED
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 Optuna, Hugging Face
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" Logging utilities."""
|
16 |
+
|
17 |
+
import logging
|
18 |
+
import os
|
19 |
+
import sys
|
20 |
+
import threading
|
21 |
+
from logging import (
|
22 |
+
CRITICAL, # NOQA
|
23 |
+
DEBUG, # NOQA
|
24 |
+
ERROR, # NOQA
|
25 |
+
FATAL, # NOQA
|
26 |
+
INFO, # NOQA
|
27 |
+
NOTSET, # NOQA
|
28 |
+
WARN, # NOQA
|
29 |
+
WARNING, # NOQA
|
30 |
+
)
|
31 |
+
from typing import Optional
|
32 |
+
|
33 |
+
from tqdm import auto as tqdm_lib
|
34 |
+
|
35 |
+
|
36 |
+
_lock = threading.Lock()
|
37 |
+
_default_handler: Optional[logging.Handler] = None
|
38 |
+
|
39 |
+
log_levels = {
|
40 |
+
"debug": logging.DEBUG,
|
41 |
+
"info": logging.INFO,
|
42 |
+
"warning": logging.WARNING,
|
43 |
+
"error": logging.ERROR,
|
44 |
+
"critical": logging.CRITICAL,
|
45 |
+
}
|
46 |
+
|
47 |
+
_default_log_level = logging.WARNING
|
48 |
+
|
49 |
+
_tqdm_active = True
|
50 |
+
|
51 |
+
|
52 |
+
def _get_default_logging_level():
|
53 |
+
"""
|
54 |
+
If DIFFUSERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
|
55 |
+
not - fall back to `_default_log_level`
|
56 |
+
"""
|
57 |
+
env_level_str = os.getenv("DIFFUSERS_VERBOSITY", None)
|
58 |
+
if env_level_str:
|
59 |
+
if env_level_str in log_levels:
|
60 |
+
return log_levels[env_level_str]
|
61 |
+
else:
|
62 |
+
logging.getLogger().warning(
|
63 |
+
f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, "
|
64 |
+
f"has to be one of: { ', '.join(log_levels.keys()) }"
|
65 |
+
)
|
66 |
+
return _default_log_level
|
67 |
+
|
68 |
+
|
69 |
+
def _get_library_name() -> str:
|
70 |
+
return __name__.split(".")[0]
|
71 |
+
|
72 |
+
|
73 |
+
def _get_library_root_logger() -> logging.Logger:
|
74 |
+
return logging.getLogger(_get_library_name())
|
75 |
+
|
76 |
+
|
77 |
+
def _configure_library_root_logger() -> None:
|
78 |
+
global _default_handler
|
79 |
+
|
80 |
+
with _lock:
|
81 |
+
if _default_handler:
|
82 |
+
# This library has already configured the library root logger.
|
83 |
+
return
|
84 |
+
_default_handler = logging.StreamHandler() # Set sys.stderr as stream.
|
85 |
+
_default_handler.flush = sys.stderr.flush
|
86 |
+
|
87 |
+
# Apply our default configuration to the library root logger.
|
88 |
+
library_root_logger = _get_library_root_logger()
|
89 |
+
library_root_logger.addHandler(_default_handler)
|
90 |
+
library_root_logger.setLevel(_get_default_logging_level())
|
91 |
+
library_root_logger.propagate = False
|
92 |
+
|
93 |
+
|
94 |
+
def _reset_library_root_logger() -> None:
|
95 |
+
global _default_handler
|
96 |
+
|
97 |
+
with _lock:
|
98 |
+
if not _default_handler:
|
99 |
+
return
|
100 |
+
|
101 |
+
library_root_logger = _get_library_root_logger()
|
102 |
+
library_root_logger.removeHandler(_default_handler)
|
103 |
+
library_root_logger.setLevel(logging.NOTSET)
|
104 |
+
_default_handler = None
|
105 |
+
|
106 |
+
|
107 |
+
def get_log_levels_dict():
|
108 |
+
return log_levels
|
109 |
+
|
110 |
+
|
111 |
+
def get_logger(name: Optional[str] = None) -> logging.Logger:
|
112 |
+
"""
|
113 |
+
Return a logger with the specified name.
|
114 |
+
|
115 |
+
This function is not supposed to be directly accessed unless you are writing a custom diffusers module.
|
116 |
+
"""
|
117 |
+
|
118 |
+
if name is None:
|
119 |
+
name = _get_library_name()
|
120 |
+
|
121 |
+
_configure_library_root_logger()
|
122 |
+
return logging.getLogger(name)
|
123 |
+
|
124 |
+
|
125 |
+
def get_verbosity() -> int:
|
126 |
+
"""
|
127 |
+
Return the current level for the 🤗 Diffusers' root logger as an int.
|
128 |
+
|
129 |
+
Returns:
|
130 |
+
`int`: The logging level.
|
131 |
+
|
132 |
+
<Tip>
|
133 |
+
|
134 |
+
🤗 Diffusers has following logging levels:
|
135 |
+
|
136 |
+
- 50: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
|
137 |
+
- 40: `diffusers.logging.ERROR`
|
138 |
+
- 30: `diffusers.logging.WARNING` or `diffusers.logging.WARN`
|
139 |
+
- 20: `diffusers.logging.INFO`
|
140 |
+
- 10: `diffusers.logging.DEBUG`
|
141 |
+
|
142 |
+
</Tip>"""
|
143 |
+
|
144 |
+
_configure_library_root_logger()
|
145 |
+
return _get_library_root_logger().getEffectiveLevel()
|
146 |
+
|
147 |
+
|
148 |
+
def set_verbosity(verbosity: int) -> None:
|
149 |
+
"""
|
150 |
+
Set the verbosity level for the 🤗 Diffusers' root logger.
|
151 |
+
|
152 |
+
Args:
|
153 |
+
verbosity (`int`):
|
154 |
+
Logging level, e.g., one of:
|
155 |
+
|
156 |
+
- `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
|
157 |
+
- `diffusers.logging.ERROR`
|
158 |
+
- `diffusers.logging.WARNING` or `diffusers.logging.WARN`
|
159 |
+
- `diffusers.logging.INFO`
|
160 |
+
- `diffusers.logging.DEBUG`
|
161 |
+
"""
|
162 |
+
|
163 |
+
_configure_library_root_logger()
|
164 |
+
_get_library_root_logger().setLevel(verbosity)
|
165 |
+
|
166 |
+
|
167 |
+
def set_verbosity_info():
|
168 |
+
"""Set the verbosity to the `INFO` level."""
|
169 |
+
return set_verbosity(INFO)
|
170 |
+
|
171 |
+
|
172 |
+
def set_verbosity_warning():
|
173 |
+
"""Set the verbosity to the `WARNING` level."""
|
174 |
+
return set_verbosity(WARNING)
|
175 |
+
|
176 |
+
|
177 |
+
def set_verbosity_debug():
|
178 |
+
"""Set the verbosity to the `DEBUG` level."""
|
179 |
+
return set_verbosity(DEBUG)
|
180 |
+
|
181 |
+
|
182 |
+
def set_verbosity_error():
|
183 |
+
"""Set the verbosity to the `ERROR` level."""
|
184 |
+
return set_verbosity(ERROR)
|
185 |
+
|
186 |
+
|
187 |
+
def disable_default_handler() -> None:
|
188 |
+
"""Disable the default handler of the HuggingFace Diffusers' root logger."""
|
189 |
+
|
190 |
+
_configure_library_root_logger()
|
191 |
+
|
192 |
+
assert _default_handler is not None
|
193 |
+
_get_library_root_logger().removeHandler(_default_handler)
|
194 |
+
|
195 |
+
|
196 |
+
def enable_default_handler() -> None:
|
197 |
+
"""Enable the default handler of the HuggingFace Diffusers' root logger."""
|
198 |
+
|
199 |
+
_configure_library_root_logger()
|
200 |
+
|
201 |
+
assert _default_handler is not None
|
202 |
+
_get_library_root_logger().addHandler(_default_handler)
|
203 |
+
|
204 |
+
|
205 |
+
def add_handler(handler: logging.Handler) -> None:
|
206 |
+
"""adds a handler to the HuggingFace Diffusers' root logger."""
|
207 |
+
|
208 |
+
_configure_library_root_logger()
|
209 |
+
|
210 |
+
assert handler is not None
|
211 |
+
_get_library_root_logger().addHandler(handler)
|
212 |
+
|
213 |
+
|
214 |
+
def remove_handler(handler: logging.Handler) -> None:
|
215 |
+
"""removes given handler from the HuggingFace Diffusers' root logger."""
|
216 |
+
|
217 |
+
_configure_library_root_logger()
|
218 |
+
|
219 |
+
assert handler is not None and handler not in _get_library_root_logger().handlers
|
220 |
+
_get_library_root_logger().removeHandler(handler)
|
221 |
+
|
222 |
+
|
223 |
+
def disable_propagation() -> None:
|
224 |
+
"""
|
225 |
+
Disable propagation of the library log outputs. Note that log propagation is disabled by default.
|
226 |
+
"""
|
227 |
+
|
228 |
+
_configure_library_root_logger()
|
229 |
+
_get_library_root_logger().propagate = False
|
230 |
+
|
231 |
+
|
232 |
+
def enable_propagation() -> None:
|
233 |
+
"""
|
234 |
+
Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent
|
235 |
+
double logging if the root logger has been configured.
|
236 |
+
"""
|
237 |
+
|
238 |
+
_configure_library_root_logger()
|
239 |
+
_get_library_root_logger().propagate = True
|
240 |
+
|
241 |
+
|
242 |
+
def enable_explicit_format() -> None:
|
243 |
+
"""
|
244 |
+
Enable explicit formatting for every HuggingFace Diffusers' logger. The explicit formatter is as follows:
|
245 |
+
```
|
246 |
+
[LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
|
247 |
+
```
|
248 |
+
All handlers currently bound to the root logger are affected by this method.
|
249 |
+
"""
|
250 |
+
handlers = _get_library_root_logger().handlers
|
251 |
+
|
252 |
+
for handler in handlers:
|
253 |
+
formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
|
254 |
+
handler.setFormatter(formatter)
|
255 |
+
|
256 |
+
|
257 |
+
def reset_format() -> None:
|
258 |
+
"""
|
259 |
+
Resets the formatting for HuggingFace Diffusers' loggers.
|
260 |
+
|
261 |
+
All handlers currently bound to the root logger are affected by this method.
|
262 |
+
"""
|
263 |
+
handlers = _get_library_root_logger().handlers
|
264 |
+
|
265 |
+
for handler in handlers:
|
266 |
+
handler.setFormatter(None)
|
267 |
+
|
268 |
+
|
269 |
+
def warning_advice(self, *args, **kwargs):
|
270 |
+
"""
|
271 |
+
This method is identical to `logger.warning()`, but if env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this
|
272 |
+
warning will not be printed
|
273 |
+
"""
|
274 |
+
no_advisory_warnings = os.getenv("DIFFUSERS_NO_ADVISORY_WARNINGS", False)
|
275 |
+
if no_advisory_warnings:
|
276 |
+
return
|
277 |
+
self.warning(*args, **kwargs)
|
278 |
+
|
279 |
+
|
280 |
+
logging.Logger.warning_advice = warning_advice
|
281 |
+
|
282 |
+
|
283 |
+
class EmptyTqdm:
|
284 |
+
"""Dummy tqdm which doesn't do anything."""
|
285 |
+
|
286 |
+
def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
|
287 |
+
self._iterator = args[0] if args else None
|
288 |
+
|
289 |
+
def __iter__(self):
|
290 |
+
return iter(self._iterator)
|
291 |
+
|
292 |
+
def __getattr__(self, _):
|
293 |
+
"""Return empty function."""
|
294 |
+
|
295 |
+
def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
|
296 |
+
return
|
297 |
+
|
298 |
+
return empty_fn
|
299 |
+
|
300 |
+
def __enter__(self):
|
301 |
+
return self
|
302 |
+
|
303 |
+
def __exit__(self, type_, value, traceback):
|
304 |
+
return
|
305 |
+
|
306 |
+
|
307 |
+
class _tqdm_cls:
|
308 |
+
def __call__(self, *args, **kwargs):
|
309 |
+
if _tqdm_active:
|
310 |
+
return tqdm_lib.tqdm(*args, **kwargs)
|
311 |
+
else:
|
312 |
+
return EmptyTqdm(*args, **kwargs)
|
313 |
+
|
314 |
+
def set_lock(self, *args, **kwargs):
|
315 |
+
self._lock = None
|
316 |
+
if _tqdm_active:
|
317 |
+
return tqdm_lib.tqdm.set_lock(*args, **kwargs)
|
318 |
+
|
319 |
+
def get_lock(self):
|
320 |
+
if _tqdm_active:
|
321 |
+
return tqdm_lib.tqdm.get_lock()
|
322 |
+
|
323 |
+
|
324 |
+
tqdm = _tqdm_cls()
|
325 |
+
|
326 |
+
|
327 |
+
def is_progress_bar_enabled() -> bool:
|
328 |
+
"""Return a boolean indicating whether tqdm progress bars are enabled."""
|
329 |
+
global _tqdm_active
|
330 |
+
return bool(_tqdm_active)
|
331 |
+
|
332 |
+
|
333 |
+
def enable_progress_bar():
|
334 |
+
"""Enable tqdm progress bar."""
|
335 |
+
global _tqdm_active
|
336 |
+
_tqdm_active = True
|
337 |
+
|
338 |
+
|
339 |
+
def disable_progress_bar():
|
340 |
+
"""Disable tqdm progress bar."""
|
341 |
+
global _tqdm_active
|
342 |
+
_tqdm_active = False
|
diffusers/utils/outputs.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
Generic utilities
|
16 |
+
"""
|
17 |
+
|
18 |
+
from collections import OrderedDict
|
19 |
+
from dataclasses import fields
|
20 |
+
from typing import Any, Tuple
|
21 |
+
|
22 |
+
import numpy as np
|
23 |
+
|
24 |
+
from .import_utils import is_torch_available
|
25 |
+
|
26 |
+
|
27 |
+
def is_tensor(x):
|
28 |
+
"""
|
29 |
+
Tests if `x` is a `torch.Tensor` or `np.ndarray`.
|
30 |
+
"""
|
31 |
+
if is_torch_available():
|
32 |
+
import torch
|
33 |
+
|
34 |
+
if isinstance(x, torch.Tensor):
|
35 |
+
return True
|
36 |
+
|
37 |
+
return isinstance(x, np.ndarray)
|
38 |
+
|
39 |
+
|
40 |
+
class BaseOutput(OrderedDict):
|
41 |
+
"""
|
42 |
+
Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
|
43 |
+
tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
|
44 |
+
python dictionary.
|
45 |
+
|
46 |
+
<Tip warning={true}>
|
47 |
+
|
48 |
+
You can't unpack a `BaseOutput` directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple
|
49 |
+
before.
|
50 |
+
|
51 |
+
</Tip>
|
52 |
+
"""
|
53 |
+
|
54 |
+
def __post_init__(self):
|
55 |
+
class_fields = fields(self)
|
56 |
+
|
57 |
+
# Safety and consistency checks
|
58 |
+
if not len(class_fields):
|
59 |
+
raise ValueError(f"{self.__class__.__name__} has no fields.")
|
60 |
+
|
61 |
+
first_field = getattr(self, class_fields[0].name)
|
62 |
+
other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
|
63 |
+
|
64 |
+
if other_fields_are_none and isinstance(first_field, dict):
|
65 |
+
for key, value in first_field.items():
|
66 |
+
self[key] = value
|
67 |
+
else:
|
68 |
+
for field in class_fields:
|
69 |
+
v = getattr(self, field.name)
|
70 |
+
if v is not None:
|
71 |
+
self[field.name] = v
|
72 |
+
|
73 |
+
def __delitem__(self, *args, **kwargs):
|
74 |
+
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
|
75 |
+
|
76 |
+
def setdefault(self, *args, **kwargs):
|
77 |
+
raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
|
78 |
+
|
79 |
+
def pop(self, *args, **kwargs):
|
80 |
+
raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
|
81 |
+
|
82 |
+
def update(self, *args, **kwargs):
|
83 |
+
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
|
84 |
+
|
85 |
+
def __getitem__(self, k):
|
86 |
+
if isinstance(k, str):
|
87 |
+
inner_dict = dict(self.items())
|
88 |
+
return inner_dict[k]
|
89 |
+
else:
|
90 |
+
return self.to_tuple()[k]
|
91 |
+
|
92 |
+
def __setattr__(self, name, value):
|
93 |
+
if name in self.keys() and value is not None:
|
94 |
+
# Don't call self.__setitem__ to avoid recursion errors
|
95 |
+
super().__setitem__(name, value)
|
96 |
+
super().__setattr__(name, value)
|
97 |
+
|
98 |
+
def __setitem__(self, key, value):
|
99 |
+
# Will raise a KeyException if needed
|
100 |
+
super().__setitem__(key, value)
|
101 |
+
# Don't call self.__setattr__ to avoid recursion errors
|
102 |
+
super().__setattr__(key, value)
|
103 |
+
|
104 |
+
def to_tuple(self) -> Tuple[Any]:
|
105 |
+
"""
|
106 |
+
Convert self to a tuple containing all the attributes/keys that are not `None`.
|
107 |
+
"""
|
108 |
+
return tuple(self[k] for k in self.keys())
|
diffusers/utils/scheduling_utils.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import importlib
|
15 |
+
import os
|
16 |
+
from dataclasses import dataclass
|
17 |
+
from enum import Enum
|
18 |
+
from typing import Any, Dict, Optional, Union
|
19 |
+
|
20 |
+
import torch
|
21 |
+
|
22 |
+
from .outputs import BaseOutput
|
23 |
+
|
24 |
+
|
25 |
+
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
|
26 |
+
|
27 |
+
|
28 |
+
# NOTE: We make this type an enum because it simplifies usage in docs and prevents
|
29 |
+
# circular imports when used for `_compatibles` within the schedulers module.
|
30 |
+
# When it's used as a type in pipelines, it really is a Union because the actual
|
31 |
+
# scheduler instance is passed in.
|
32 |
+
class KarrasDiffusionSchedulers(Enum):
|
33 |
+
DDIMScheduler = 1
|
34 |
+
DDPMScheduler = 2
|
35 |
+
PNDMScheduler = 3
|
36 |
+
LMSDiscreteScheduler = 4
|
37 |
+
EulerDiscreteScheduler = 5
|
38 |
+
HeunDiscreteScheduler = 6
|
39 |
+
EulerAncestralDiscreteScheduler = 7
|
40 |
+
DPMSolverMultistepScheduler = 8
|
41 |
+
DPMSolverSinglestepScheduler = 9
|
42 |
+
KDPM2DiscreteScheduler = 10
|
43 |
+
KDPM2AncestralDiscreteScheduler = 11
|
44 |
+
DEISMultistepScheduler = 12
|
45 |
+
UniPCMultistepScheduler = 13
|
46 |
+
|
47 |
+
|
48 |
+
@dataclass
|
49 |
+
class SchedulerOutput(BaseOutput):
|
50 |
+
"""
|
51 |
+
Base class for the scheduler's step function output.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
55 |
+
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
|
56 |
+
denoising loop.
|
57 |
+
"""
|
58 |
+
|
59 |
+
prev_sample: torch.FloatTensor
|
60 |
+
|
61 |
+
|
62 |
+
class SchedulerMixin:
|
63 |
+
"""
|
64 |
+
Mixin containing common functions for the schedulers.
|
65 |
+
|
66 |
+
Class attributes:
|
67 |
+
- **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
|
68 |
+
`from_config` can be used from a class different than the one used to save the config (should be overridden
|
69 |
+
by parent class).
|
70 |
+
"""
|
71 |
+
|
72 |
+
config_name = SCHEDULER_CONFIG_NAME
|
73 |
+
_compatibles = []
|
74 |
+
has_compatibles = True
|
75 |
+
|
76 |
+
@classmethod
|
77 |
+
def from_pretrained(
|
78 |
+
cls,
|
79 |
+
pretrained_model_name_or_path: Dict[str, Any] = None,
|
80 |
+
subfolder: Optional[str] = None,
|
81 |
+
return_unused_kwargs=False,
|
82 |
+
**kwargs,
|
83 |
+
):
|
84 |
+
r"""
|
85 |
+
Instantiate a Scheduler class from a pre-defined JSON configuration file inside a directory or Hub repo.
|
86 |
+
|
87 |
+
Parameters:
|
88 |
+
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
89 |
+
Can be either:
|
90 |
+
|
91 |
+
- A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
|
92 |
+
organization name, like `google/ddpm-celebahq-256`.
|
93 |
+
- A path to a *directory* containing the schedluer configurations saved using
|
94 |
+
[`~SchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`.
|
95 |
+
subfolder (`str`, *optional*):
|
96 |
+
In case the relevant files are located inside a subfolder of the model repo (either remote in
|
97 |
+
huggingface.co or downloaded locally), you can specify the folder name here.
|
98 |
+
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
99 |
+
Whether kwargs that are not consumed by the Python class should be returned or not.
|
100 |
+
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
101 |
+
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
102 |
+
standard cache should not be used.
|
103 |
+
force_download (`bool`, *optional*, defaults to `False`):
|
104 |
+
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
105 |
+
cached versions if they exist.
|
106 |
+
resume_download (`bool`, *optional*, defaults to `False`):
|
107 |
+
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
108 |
+
file exists.
|
109 |
+
proxies (`Dict[str, str]`, *optional*):
|
110 |
+
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
111 |
+
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
112 |
+
output_loading_info(`bool`, *optional*, defaults to `False`):
|
113 |
+
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
114 |
+
local_files_only(`bool`, *optional*, defaults to `False`):
|
115 |
+
Whether or not to only look at local files (i.e., do not try to download the model).
|
116 |
+
use_auth_token (`str` or *bool*, *optional*):
|
117 |
+
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
118 |
+
when running `transformers-cli login` (stored in `~/.huggingface`).
|
119 |
+
revision (`str`, *optional*, defaults to `"main"`):
|
120 |
+
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
121 |
+
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
122 |
+
identifier allowed by git.
|
123 |
+
|
124 |
+
<Tip>
|
125 |
+
|
126 |
+
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
|
127 |
+
models](https://huggingface.co/docs/hub/models-gated#gated-models).
|
128 |
+
|
129 |
+
</Tip>
|
130 |
+
|
131 |
+
<Tip>
|
132 |
+
|
133 |
+
Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
|
134 |
+
use this method in a firewalled environment.
|
135 |
+
|
136 |
+
</Tip>
|
137 |
+
|
138 |
+
"""
|
139 |
+
config, kwargs, commit_hash = cls.load_config(
|
140 |
+
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
141 |
+
subfolder=subfolder,
|
142 |
+
return_unused_kwargs=True,
|
143 |
+
return_commit_hash=True,
|
144 |
+
**kwargs,
|
145 |
+
)
|
146 |
+
return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
|
147 |
+
|
148 |
+
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
149 |
+
"""
|
150 |
+
Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the
|
151 |
+
[`~SchedulerMixin.from_pretrained`] class method.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
save_directory (`str` or `os.PathLike`):
|
155 |
+
Directory where the configuration JSON file will be saved (will be created if it does not exist).
|
156 |
+
"""
|
157 |
+
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
|
158 |
+
|
159 |
+
@property
|
160 |
+
def compatibles(self):
|
161 |
+
"""
|
162 |
+
Returns all schedulers that are compatible with this scheduler
|
163 |
+
|
164 |
+
Returns:
|
165 |
+
`List[SchedulerMixin]`: List of compatible schedulers
|
166 |
+
"""
|
167 |
+
return self._get_compatibles()
|
168 |
+
|
169 |
+
@classmethod
|
170 |
+
def _get_compatibles(cls):
|
171 |
+
compatible_classes_str = list(set([cls.__name__] + cls._compatibles))
|
172 |
+
diffusers_library = importlib.import_module(__name__.split(".")[0])
|
173 |
+
compatible_classes = [
|
174 |
+
getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
|
175 |
+
]
|
176 |
+
return compatible_classes
|
diffusers/utils/torch_utils.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
PyTorch utilities: Utilities related to PyTorch
|
16 |
+
"""
|
17 |
+
from typing import List, Optional, Tuple, Union
|
18 |
+
|
19 |
+
from . import logging
|
20 |
+
from .import_utils import is_torch_available, is_torch_version
|
21 |
+
|
22 |
+
if is_torch_available():
|
23 |
+
import torch
|
24 |
+
|
25 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
26 |
+
|
27 |
+
try:
|
28 |
+
from torch._dynamo import allow_in_graph as maybe_allow_in_graph
|
29 |
+
except (ImportError, ModuleNotFoundError):
|
30 |
+
|
31 |
+
def maybe_allow_in_graph(cls):
|
32 |
+
return cls
|
33 |
+
|
34 |
+
|
35 |
+
def randn_tensor(
|
36 |
+
shape: Union[Tuple, List],
|
37 |
+
generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
|
38 |
+
device: Optional["torch.device"] = None,
|
39 |
+
dtype: Optional["torch.dtype"] = None,
|
40 |
+
layout: Optional["torch.layout"] = None,
|
41 |
+
):
|
42 |
+
"""This is a helper function that allows to create random tensors on the desired `device` with the desired `dtype`. When
|
43 |
+
passing a list of generators one can seed each batched size individually. If CPU generators are passed the tensor
|
44 |
+
will always be created on CPU.
|
45 |
+
"""
|
46 |
+
# device on which tensor is created defaults to device
|
47 |
+
rand_device = device
|
48 |
+
batch_size = shape[0]
|
49 |
+
|
50 |
+
layout = layout or torch.strided
|
51 |
+
device = device or torch.device("cpu")
|
52 |
+
|
53 |
+
if generator is not None:
|
54 |
+
gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
|
55 |
+
if gen_device_type != device.type and gen_device_type == "cpu":
|
56 |
+
rand_device = "cpu"
|
57 |
+
if device != "mps":
|
58 |
+
logger.info(
|
59 |
+
f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
|
60 |
+
f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
|
61 |
+
f" slighly speed up this function by passing a generator that was created on the {device} device."
|
62 |
+
)
|
63 |
+
elif gen_device_type != device.type and gen_device_type == "cuda":
|
64 |
+
raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
|
65 |
+
|
66 |
+
if isinstance(generator, list):
|
67 |
+
shape = (1,) + shape[1:]
|
68 |
+
latents = [
|
69 |
+
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
|
70 |
+
for i in range(batch_size)
|
71 |
+
]
|
72 |
+
latents = torch.cat(latents, dim=0).to(device)
|
73 |
+
else:
|
74 |
+
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
|
75 |
+
|
76 |
+
return latents
|
77 |
+
|
78 |
+
|
79 |
+
def is_compiled_module(module):
|
80 |
+
"""Check whether the module was compiled with torch.compile()"""
|
81 |
+
if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"):
|
82 |
+
return False
|
83 |
+
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
|
run_gradio.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import gradio as gr
|
3 |
+
import soundfile as sf
|
4 |
+
import numpy as np
|
5 |
+
import random, os
|
6 |
+
|
7 |
+
from consistencytta import ConsistencyTTA
|
8 |
+
|
9 |
+
|
10 |
+
def seed_all(seed):
|
11 |
+
""" Seed all random number generators. """
|
12 |
+
seed = int(seed)
|
13 |
+
random.seed(seed)
|
14 |
+
np.random.seed(seed)
|
15 |
+
torch.manual_seed(seed)
|
16 |
+
torch.cuda.manual_seed(seed)
|
17 |
+
torch.cuda.manual_seed_all(seed)
|
18 |
+
torch.cuda.random.manual_seed(seed)
|
19 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
20 |
+
torch.backends.cudnn.benchmark = False
|
21 |
+
torch.backends.cudnn.deterministic = True
|
22 |
+
|
23 |
+
|
24 |
+
device = torch.device(
|
25 |
+
"cuda:0" if torch.cuda.is_available() else
|
26 |
+
"mps" if torch.backends.mps.is_available() else "cpu"
|
27 |
+
)
|
28 |
+
sr = 16000
|
29 |
+
|
30 |
+
# Build ConsistencyTTA model
|
31 |
+
consistencytta = ConsistencyTTA().to(device)
|
32 |
+
consistencytta.eval()
|
33 |
+
consistencytta.requires_grad_(False)
|
34 |
+
|
35 |
+
|
36 |
+
def generate(prompt: str, seed: str = '', cfg_weight: float = 4.):
|
37 |
+
""" Generate audio from a given prompt.
|
38 |
+
Args:
|
39 |
+
prompt (str): Text prompt to generate audio from.
|
40 |
+
seed (str, optional): Random seed. Defaults to '', which means no seed.
|
41 |
+
"""
|
42 |
+
if seed != '':
|
43 |
+
try:
|
44 |
+
seed_all(int(seed))
|
45 |
+
except:
|
46 |
+
pass
|
47 |
+
|
48 |
+
with torch.no_grad():
|
49 |
+
with torch.autocast(
|
50 |
+
device_type="cuda", dtype=torch.bfloat16, enabled=torch.cuda.is_available()
|
51 |
+
):
|
52 |
+
wav = consistencytta(
|
53 |
+
[prompt], num_steps=1, cfg_scale_input=cfg_weight, cfg_scale_post=1., sr=sr
|
54 |
+
)
|
55 |
+
sf.write("output.wav", wav.T, samplerate=sr, subtype='PCM_16')
|
56 |
+
|
57 |
+
return "output.wav"
|
58 |
+
|
59 |
+
|
60 |
+
# Generate test audio
|
61 |
+
print("Generating test audio...")
|
62 |
+
generate("A dog barks as a train passes by.", seed=1)
|
63 |
+
print("Test audio generated successfully! Starting Gradio interface...")
|
64 |
+
|
65 |
+
# Launch Gradio interface
|
66 |
+
iface = gr.Interface(
|
67 |
+
fn=generate,
|
68 |
+
inputs=[
|
69 |
+
gr.Textbox(
|
70 |
+
label="Text", value="Several people cheer and scream and speak as water flows hard."
|
71 |
+
),
|
72 |
+
gr.Textbox(label="Random Seed (Optional)", value=''),
|
73 |
+
gr.Slider(
|
74 |
+
minimum=0., maximum=8., value=3.5, label="Classifier-Free Guidance Strength"
|
75 |
+
)],
|
76 |
+
outputs="audio",
|
77 |
+
title="ConsistencyTTA: Accelerating Diffusion-Based Text-to-Audio " \
|
78 |
+
"Generation with Consistency Distillation",
|
79 |
+
description="This is the official demo page for <a href='https://consistency-tta.github." \
|
80 |
+
"io' target=“blank”>ConsistencyTTA</a>, a model that accelerates " \
|
81 |
+
"diffusion-based text-to-audio generation hundreds of times with consistency " \
|
82 |
+
"models. <br> Here, the audio is generated within a single non-autoregressive " \
|
83 |
+
"forward pass from the CLAP-finetuned ConsistencyTTA checkpoint. <br> Since " \
|
84 |
+
"the training dataset does not include speech, the model is not expected to " \
|
85 |
+
"generate coherent speech. <br> Have fun!"
|
86 |
+
)
|
87 |
+
iface.launch(share=True)
|
tango_diffusion_light.json
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_class_name": "UNet2DConditionModel",
|
3 |
+
"_diffusers_version": "0.10.0.dev0",
|
4 |
+
"act_fn": "silu",
|
5 |
+
"attention_head_dim": [
|
6 |
+
5,
|
7 |
+
10,
|
8 |
+
20,
|
9 |
+
20
|
10 |
+
],
|
11 |
+
"block_out_channels": [
|
12 |
+
256,
|
13 |
+
512,
|
14 |
+
1024,
|
15 |
+
1024
|
16 |
+
],
|
17 |
+
"center_input_sample": false,
|
18 |
+
"cross_attention_dim": 1024,
|
19 |
+
"down_block_types": [
|
20 |
+
"CrossAttnDownBlock2D",
|
21 |
+
"CrossAttnDownBlock2D",
|
22 |
+
"CrossAttnDownBlock2D",
|
23 |
+
"DownBlock2D"
|
24 |
+
],
|
25 |
+
"downsample_padding": 1,
|
26 |
+
"dual_cross_attention": false,
|
27 |
+
"flip_sin_to_cos": true,
|
28 |
+
"freq_shift": 0,
|
29 |
+
"in_channels": 8,
|
30 |
+
"layers_per_block": 2,
|
31 |
+
"mid_block_scale_factor": 1,
|
32 |
+
"norm_eps": 1e-05,
|
33 |
+
"norm_num_groups": 32,
|
34 |
+
"num_class_embeds": null,
|
35 |
+
"only_cross_attention": false,
|
36 |
+
"out_channels": 8,
|
37 |
+
"sample_size": [32, 2],
|
38 |
+
"up_block_types": [
|
39 |
+
"UpBlock2D",
|
40 |
+
"CrossAttnUpBlock2D",
|
41 |
+
"CrossAttnUpBlock2D",
|
42 |
+
"CrossAttnUpBlock2D"
|
43 |
+
],
|
44 |
+
"use_linear_projection": true,
|
45 |
+
"upcast_attention": true
|
46 |
+
}
|