lora
Browse files- LoRA-EXTRACTOR +1 -0
- Lora/RUNTHIS.txt +1 -0
- Lora/lib/FileToOpen.exe +0 -0
- Lora/lib/FileToSave.exe +0 -0
- Lora/lib/__pycache__/lora.cpython-39.pyc +0 -0
- Lora/lib/__pycache__/model_util.cpython-39.pyc +0 -0
- Lora/lib/__pycache__/train_util.cpython-39.pyc +0 -0
- Lora/lib/extract_lora_from_models.py +164 -0
- Lora/lib/lora.py +237 -0
- Lora/lib/model_util.py +1180 -0
- Lora/lib/qwerty.py +91 -0
- Lora/lib/train_util.py +1766 -0
- Lora/matous_LORA.safetensors +3 -0
LoRA-EXTRACTOR
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Subproject commit fa8e800cf5ab8a87f4cf449507b79c05d190088d
|
Lora/RUNTHIS.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
!python /content/stable-diffusion-webui/matousbecvar_sd1.5_drbth/LoRA-EXTRACTOR/lib/extract_lora_from_models.py --save_precision fp16 --save_to matous_LORA.safetensors --model_org /content/stable-diffusion-webui/models/Stable-diffusion/stable_diffusion_v1_5.safetensors --model_tuned /content/stable-diffusion-webui/matousbecvar_sd1.5_drbth/matousbecvar.safetensors --dim 128
|
Lora/lib/FileToOpen.exe
ADDED
Binary file (16.4 kB). View file
|
|
Lora/lib/FileToSave.exe
ADDED
Binary file (15.9 kB). View file
|
|
Lora/lib/__pycache__/lora.cpython-39.pyc
ADDED
Binary file (7.38 kB). View file
|
|
Lora/lib/__pycache__/model_util.cpython-39.pyc
ADDED
Binary file (29.7 kB). View file
|
|
Lora/lib/__pycache__/train_util.cpython-39.pyc
ADDED
Binary file (56.4 kB). View file
|
|
Lora/lib/extract_lora_from_models.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# extract approximating LoRA by svd from two SD models
|
2 |
+
# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
|
3 |
+
# Thanks to cloneofsimo!
|
4 |
+
|
5 |
+
import argparse
|
6 |
+
import os
|
7 |
+
import torch
|
8 |
+
from safetensors.torch import load_file, save_file
|
9 |
+
from tqdm import tqdm
|
10 |
+
import model_util
|
11 |
+
import lora
|
12 |
+
|
13 |
+
|
14 |
+
CLAMP_QUANTILE = 0.99
|
15 |
+
MIN_DIFF = 1e-6
|
16 |
+
|
17 |
+
|
18 |
+
def save_to_file(file_name, model, state_dict, dtype):
|
19 |
+
if dtype is not None:
|
20 |
+
for key in list(state_dict.keys()):
|
21 |
+
if type(state_dict[key]) == torch.Tensor:
|
22 |
+
state_dict[key] = state_dict[key].to(dtype)
|
23 |
+
|
24 |
+
if os.path.splitext(file_name)[1] == '.safetensors':
|
25 |
+
save_file(model, file_name)
|
26 |
+
else:
|
27 |
+
torch.save(model, file_name)
|
28 |
+
|
29 |
+
|
30 |
+
def svd(args):
|
31 |
+
def str_to_dtype(p):
|
32 |
+
if p == 'float':
|
33 |
+
return torch.float
|
34 |
+
if p == 'fp16':
|
35 |
+
return torch.float16
|
36 |
+
if p == 'bf16':
|
37 |
+
return torch.bfloat16
|
38 |
+
return None
|
39 |
+
|
40 |
+
save_dtype = str_to_dtype(args.save_precision)
|
41 |
+
|
42 |
+
print(f"loading SD model : {args.model_org}")
|
43 |
+
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org)
|
44 |
+
print(f"loading SD model : {args.model_tuned}")
|
45 |
+
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
|
46 |
+
|
47 |
+
# create LoRA network to extract weights: Use dim (rank) as alpha
|
48 |
+
lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o)
|
49 |
+
lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t)
|
50 |
+
assert len(lora_network_o.text_encoder_loras) == len(
|
51 |
+
lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
|
52 |
+
|
53 |
+
# get diffs
|
54 |
+
diffs = {}
|
55 |
+
text_encoder_different = False
|
56 |
+
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)):
|
57 |
+
lora_name = lora_o.lora_name
|
58 |
+
module_o = lora_o.org_module
|
59 |
+
module_t = lora_t.org_module
|
60 |
+
diff = module_t.weight - module_o.weight
|
61 |
+
|
62 |
+
# Text Encoder might be same
|
63 |
+
if torch.max(torch.abs(diff)) > MIN_DIFF:
|
64 |
+
text_encoder_different = True
|
65 |
+
|
66 |
+
diff = diff.float()
|
67 |
+
diffs[lora_name] = diff
|
68 |
+
|
69 |
+
if not text_encoder_different:
|
70 |
+
print("Text encoder is same. Extract U-Net only.")
|
71 |
+
lora_network_o.text_encoder_loras = []
|
72 |
+
diffs = {}
|
73 |
+
|
74 |
+
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
|
75 |
+
lora_name = lora_o.lora_name
|
76 |
+
module_o = lora_o.org_module
|
77 |
+
module_t = lora_t.org_module
|
78 |
+
diff = module_t.weight - module_o.weight
|
79 |
+
diff = diff.float()
|
80 |
+
|
81 |
+
if args.device:
|
82 |
+
diff = diff.to(args.device)
|
83 |
+
|
84 |
+
diffs[lora_name] = diff
|
85 |
+
|
86 |
+
# make LoRA with svd
|
87 |
+
print("calculating by svd")
|
88 |
+
rank = args.dim
|
89 |
+
lora_weights = {}
|
90 |
+
with torch.no_grad():
|
91 |
+
for lora_name, mat in tqdm(list(diffs.items())):
|
92 |
+
conv2d = (len(mat.size()) == 4)
|
93 |
+
if conv2d:
|
94 |
+
mat = mat.squeeze()
|
95 |
+
|
96 |
+
U, S, Vh = torch.linalg.svd(mat)
|
97 |
+
|
98 |
+
U = U[:, :rank]
|
99 |
+
S = S[:rank]
|
100 |
+
U = U @ torch.diag(S)
|
101 |
+
|
102 |
+
Vh = Vh[:rank, :]
|
103 |
+
|
104 |
+
dist = torch.cat([U.flatten(), Vh.flatten()])
|
105 |
+
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
106 |
+
low_val = -hi_val
|
107 |
+
|
108 |
+
U = U.clamp(low_val, hi_val)
|
109 |
+
Vh = Vh.clamp(low_val, hi_val)
|
110 |
+
|
111 |
+
lora_weights[lora_name] = (U, Vh)
|
112 |
+
|
113 |
+
# make state dict for LoRA
|
114 |
+
lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True) # to make state dict
|
115 |
+
lora_sd = lora_network_o.state_dict()
|
116 |
+
print(f"LoRA has {len(lora_sd)} weights.")
|
117 |
+
|
118 |
+
for key in list(lora_sd.keys()):
|
119 |
+
if "alpha" in key:
|
120 |
+
continue
|
121 |
+
|
122 |
+
lora_name = key.split('.')[0]
|
123 |
+
i = 0 if "lora_up" in key else 1
|
124 |
+
|
125 |
+
weights = lora_weights[lora_name][i]
|
126 |
+
# print(key, i, weights.size(), lora_sd[key].size())
|
127 |
+
if len(lora_sd[key].size()) == 4:
|
128 |
+
weights = weights.unsqueeze(2).unsqueeze(3)
|
129 |
+
|
130 |
+
assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}"
|
131 |
+
lora_sd[key] = weights
|
132 |
+
|
133 |
+
# load state dict to LoRA and save it
|
134 |
+
info = lora_network_o.load_state_dict(lora_sd)
|
135 |
+
print(f"Loading extracted LoRA weights: {info}")
|
136 |
+
|
137 |
+
dir_name = os.path.dirname(args.save_to)
|
138 |
+
if dir_name and not os.path.exists(dir_name):
|
139 |
+
os.makedirs(dir_name, exist_ok=True)
|
140 |
+
|
141 |
+
# minimum metadata
|
142 |
+
metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
|
143 |
+
|
144 |
+
lora_network_o.save_weights(args.save_to, save_dtype, metadata)
|
145 |
+
print(f"LoRA weights are saved to: {args.save_to}")
|
146 |
+
|
147 |
+
|
148 |
+
if __name__ == '__main__':
|
149 |
+
parser = argparse.ArgumentParser()
|
150 |
+
parser.add_argument("--v2", action='store_true',
|
151 |
+
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
|
152 |
+
parser.add_argument("--save_precision", type=str, default=None,
|
153 |
+
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat")
|
154 |
+
parser.add_argument("--model_org", type=str, default=None,
|
155 |
+
help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors")
|
156 |
+
parser.add_argument("--model_tuned", type=str, default=None,
|
157 |
+
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors")
|
158 |
+
parser.add_argument("--save_to", type=str, default=None,
|
159 |
+
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
160 |
+
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
|
161 |
+
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
162 |
+
|
163 |
+
args = parser.parse_args()
|
164 |
+
svd(args)
|
Lora/lib/lora.py
ADDED
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# LoRA network module
|
2 |
+
# reference:
|
3 |
+
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
|
4 |
+
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
|
5 |
+
|
6 |
+
import math
|
7 |
+
import os
|
8 |
+
from typing import List
|
9 |
+
import torch
|
10 |
+
|
11 |
+
import train_util
|
12 |
+
|
13 |
+
|
14 |
+
class LoRAModule(torch.nn.Module):
|
15 |
+
"""
|
16 |
+
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
|
20 |
+
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
21 |
+
super().__init__()
|
22 |
+
self.lora_name = lora_name
|
23 |
+
self.lora_dim = lora_dim
|
24 |
+
|
25 |
+
if org_module.__class__.__name__ == 'Conv2d':
|
26 |
+
in_dim = org_module.in_channels
|
27 |
+
out_dim = org_module.out_channels
|
28 |
+
self.lora_down = torch.nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
|
29 |
+
self.lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
|
30 |
+
else:
|
31 |
+
in_dim = org_module.in_features
|
32 |
+
out_dim = org_module.out_features
|
33 |
+
self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
|
34 |
+
self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
|
35 |
+
|
36 |
+
if type(alpha) == torch.Tensor:
|
37 |
+
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
38 |
+
alpha = lora_dim if alpha is None or alpha == 0 else alpha
|
39 |
+
self.scale = alpha / self.lora_dim
|
40 |
+
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
|
41 |
+
|
42 |
+
# same as microsoft's
|
43 |
+
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
44 |
+
torch.nn.init.zeros_(self.lora_up.weight)
|
45 |
+
|
46 |
+
self.multiplier = multiplier
|
47 |
+
self.org_module = org_module # remove in applying
|
48 |
+
|
49 |
+
def apply_to(self):
|
50 |
+
self.org_forward = self.org_module.forward
|
51 |
+
self.org_module.forward = self.forward
|
52 |
+
del self.org_module
|
53 |
+
|
54 |
+
def forward(self, x):
|
55 |
+
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
56 |
+
|
57 |
+
|
58 |
+
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
59 |
+
if network_dim is None:
|
60 |
+
network_dim = 4 # default
|
61 |
+
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
|
62 |
+
return network
|
63 |
+
|
64 |
+
|
65 |
+
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs):
|
66 |
+
if os.path.splitext(file)[1] == '.safetensors':
|
67 |
+
from safetensors.torch import load_file, safe_open
|
68 |
+
weights_sd = load_file(file)
|
69 |
+
else:
|
70 |
+
weights_sd = torch.load(file, map_location='cpu')
|
71 |
+
|
72 |
+
# get dim (rank)
|
73 |
+
network_alpha = None
|
74 |
+
network_dim = None
|
75 |
+
for key, value in weights_sd.items():
|
76 |
+
if network_alpha is None and 'alpha' in key:
|
77 |
+
network_alpha = value
|
78 |
+
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
|
79 |
+
network_dim = value.size()[0]
|
80 |
+
|
81 |
+
if network_alpha is None:
|
82 |
+
network_alpha = network_dim
|
83 |
+
|
84 |
+
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
|
85 |
+
network.weights_sd = weights_sd
|
86 |
+
return network
|
87 |
+
|
88 |
+
|
89 |
+
class LoRANetwork(torch.nn.Module):
|
90 |
+
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
|
91 |
+
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
92 |
+
LORA_PREFIX_UNET = 'lora_unet'
|
93 |
+
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
94 |
+
|
95 |
+
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
|
96 |
+
super().__init__()
|
97 |
+
self.multiplier = multiplier
|
98 |
+
self.lora_dim = lora_dim
|
99 |
+
self.alpha = alpha
|
100 |
+
|
101 |
+
# create module instances
|
102 |
+
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
|
103 |
+
loras = []
|
104 |
+
for name, module in root_module.named_modules():
|
105 |
+
if module.__class__.__name__ in target_replace_modules:
|
106 |
+
for child_name, child_module in module.named_modules():
|
107 |
+
if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
|
108 |
+
lora_name = prefix + '.' + name + '.' + child_name
|
109 |
+
lora_name = lora_name.replace('.', '_')
|
110 |
+
lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
|
111 |
+
loras.append(lora)
|
112 |
+
return loras
|
113 |
+
|
114 |
+
self.text_encoder_loras = create_modules(LoRANetwork.LORA_PREFIX_TEXT_ENCODER,
|
115 |
+
text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
116 |
+
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
117 |
+
|
118 |
+
self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE)
|
119 |
+
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
120 |
+
|
121 |
+
self.weights_sd = None
|
122 |
+
|
123 |
+
# assertion
|
124 |
+
names = set()
|
125 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
126 |
+
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
127 |
+
names.add(lora.lora_name)
|
128 |
+
|
129 |
+
def load_weights(self, file):
|
130 |
+
if os.path.splitext(file)[1] == '.safetensors':
|
131 |
+
from safetensors.torch import load_file, safe_open
|
132 |
+
self.weights_sd = load_file(file)
|
133 |
+
else:
|
134 |
+
self.weights_sd = torch.load(file, map_location='cpu')
|
135 |
+
|
136 |
+
def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None):
|
137 |
+
if self.weights_sd:
|
138 |
+
weights_has_text_encoder = weights_has_unet = False
|
139 |
+
for key in self.weights_sd.keys():
|
140 |
+
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
|
141 |
+
weights_has_text_encoder = True
|
142 |
+
elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
|
143 |
+
weights_has_unet = True
|
144 |
+
|
145 |
+
if apply_text_encoder is None:
|
146 |
+
apply_text_encoder = weights_has_text_encoder
|
147 |
+
else:
|
148 |
+
assert apply_text_encoder == weights_has_text_encoder, f"text encoder weights: {weights_has_text_encoder} but text encoder flag: {apply_text_encoder} / 重みとText Encoderのフラグが矛盾しています"
|
149 |
+
|
150 |
+
if apply_unet is None:
|
151 |
+
apply_unet = weights_has_unet
|
152 |
+
else:
|
153 |
+
assert apply_unet == weights_has_unet, f"u-net weights: {weights_has_unet} but u-net flag: {apply_unet} / 重みとU-Netのフラグが矛盾しています"
|
154 |
+
else:
|
155 |
+
assert apply_text_encoder is not None and apply_unet is not None, f"internal error: flag not set"
|
156 |
+
|
157 |
+
if apply_text_encoder:
|
158 |
+
print("enable LoRA for text encoder")
|
159 |
+
else:
|
160 |
+
self.text_encoder_loras = []
|
161 |
+
|
162 |
+
if apply_unet:
|
163 |
+
print("enable LoRA for U-Net")
|
164 |
+
else:
|
165 |
+
self.unet_loras = []
|
166 |
+
|
167 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
168 |
+
lora.apply_to()
|
169 |
+
self.add_module(lora.lora_name, lora)
|
170 |
+
|
171 |
+
if self.weights_sd:
|
172 |
+
# if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros)
|
173 |
+
info = self.load_state_dict(self.weights_sd, False)
|
174 |
+
print(f"weights are loaded: {info}")
|
175 |
+
|
176 |
+
def enable_gradient_checkpointing(self):
|
177 |
+
# not supported
|
178 |
+
pass
|
179 |
+
|
180 |
+
def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
|
181 |
+
def enumerate_params(loras):
|
182 |
+
params = []
|
183 |
+
for lora in loras:
|
184 |
+
params.extend(lora.parameters())
|
185 |
+
return params
|
186 |
+
|
187 |
+
self.requires_grad_(True)
|
188 |
+
all_params = []
|
189 |
+
|
190 |
+
if self.text_encoder_loras:
|
191 |
+
param_data = {'params': enumerate_params(self.text_encoder_loras)}
|
192 |
+
if text_encoder_lr is not None:
|
193 |
+
param_data['lr'] = text_encoder_lr
|
194 |
+
all_params.append(param_data)
|
195 |
+
|
196 |
+
if self.unet_loras:
|
197 |
+
param_data = {'params': enumerate_params(self.unet_loras)}
|
198 |
+
if unet_lr is not None:
|
199 |
+
param_data['lr'] = unet_lr
|
200 |
+
all_params.append(param_data)
|
201 |
+
|
202 |
+
return all_params
|
203 |
+
|
204 |
+
def prepare_grad_etc(self, text_encoder, unet):
|
205 |
+
self.requires_grad_(True)
|
206 |
+
|
207 |
+
def on_epoch_start(self, text_encoder, unet):
|
208 |
+
self.train()
|
209 |
+
|
210 |
+
def get_trainable_params(self):
|
211 |
+
return self.parameters()
|
212 |
+
|
213 |
+
def save_weights(self, file, dtype, metadata):
|
214 |
+
if metadata is not None and len(metadata) == 0:
|
215 |
+
metadata = None
|
216 |
+
|
217 |
+
state_dict = self.state_dict()
|
218 |
+
|
219 |
+
if dtype is not None:
|
220 |
+
for key in list(state_dict.keys()):
|
221 |
+
v = state_dict[key]
|
222 |
+
v = v.detach().clone().to("cpu").to(dtype)
|
223 |
+
state_dict[key] = v
|
224 |
+
|
225 |
+
if os.path.splitext(file)[1] == '.safetensors':
|
226 |
+
from safetensors.torch import save_file
|
227 |
+
|
228 |
+
# Precalculate model hashes to save time on indexing
|
229 |
+
if metadata is None:
|
230 |
+
metadata = {}
|
231 |
+
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
232 |
+
metadata["sshs_model_hash"] = model_hash
|
233 |
+
metadata["sshs_legacy_hash"] = legacy_hash
|
234 |
+
|
235 |
+
save_file(state_dict, file, metadata)
|
236 |
+
else:
|
237 |
+
torch.save(state_dict, file)
|
Lora/lib/model_util.py
ADDED
@@ -0,0 +1,1180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# v1: split from train_db_fixed.py.
|
2 |
+
# v2: support safetensors
|
3 |
+
|
4 |
+
import math
|
5 |
+
import os
|
6 |
+
import torch
|
7 |
+
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
|
8 |
+
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
9 |
+
from safetensors.torch import load_file, save_file
|
10 |
+
|
11 |
+
# DiffUsers版StableDiffusionのモデルパラメータ
|
12 |
+
NUM_TRAIN_TIMESTEPS = 1000
|
13 |
+
BETA_START = 0.00085
|
14 |
+
BETA_END = 0.0120
|
15 |
+
|
16 |
+
UNET_PARAMS_MODEL_CHANNELS = 320
|
17 |
+
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
|
18 |
+
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
|
19 |
+
UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32`
|
20 |
+
UNET_PARAMS_IN_CHANNELS = 4
|
21 |
+
UNET_PARAMS_OUT_CHANNELS = 4
|
22 |
+
UNET_PARAMS_NUM_RES_BLOCKS = 2
|
23 |
+
UNET_PARAMS_CONTEXT_DIM = 768
|
24 |
+
UNET_PARAMS_NUM_HEADS = 8
|
25 |
+
|
26 |
+
VAE_PARAMS_Z_CHANNELS = 4
|
27 |
+
VAE_PARAMS_RESOLUTION = 256
|
28 |
+
VAE_PARAMS_IN_CHANNELS = 3
|
29 |
+
VAE_PARAMS_OUT_CH = 3
|
30 |
+
VAE_PARAMS_CH = 128
|
31 |
+
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
|
32 |
+
VAE_PARAMS_NUM_RES_BLOCKS = 2
|
33 |
+
|
34 |
+
# V2
|
35 |
+
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
|
36 |
+
V2_UNET_PARAMS_CONTEXT_DIM = 1024
|
37 |
+
|
38 |
+
# Diffusersの設定を読み込むための参照モデル
|
39 |
+
DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
|
40 |
+
DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1"
|
41 |
+
|
42 |
+
|
43 |
+
# region StableDiffusion->Diffusersの変換コード
|
44 |
+
# convert_original_stable_diffusion_to_diffusers をコピーして修正している(ASL 2.0)
|
45 |
+
|
46 |
+
|
47 |
+
def shave_segments(path, n_shave_prefix_segments=1):
|
48 |
+
"""
|
49 |
+
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
50 |
+
"""
|
51 |
+
if n_shave_prefix_segments >= 0:
|
52 |
+
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
53 |
+
else:
|
54 |
+
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
55 |
+
|
56 |
+
|
57 |
+
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
58 |
+
"""
|
59 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
60 |
+
"""
|
61 |
+
mapping = []
|
62 |
+
for old_item in old_list:
|
63 |
+
new_item = old_item.replace("in_layers.0", "norm1")
|
64 |
+
new_item = new_item.replace("in_layers.2", "conv1")
|
65 |
+
|
66 |
+
new_item = new_item.replace("out_layers.0", "norm2")
|
67 |
+
new_item = new_item.replace("out_layers.3", "conv2")
|
68 |
+
|
69 |
+
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
70 |
+
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
71 |
+
|
72 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
73 |
+
|
74 |
+
mapping.append({"old": old_item, "new": new_item})
|
75 |
+
|
76 |
+
return mapping
|
77 |
+
|
78 |
+
|
79 |
+
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
80 |
+
"""
|
81 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
82 |
+
"""
|
83 |
+
mapping = []
|
84 |
+
for old_item in old_list:
|
85 |
+
new_item = old_item
|
86 |
+
|
87 |
+
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
88 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
89 |
+
|
90 |
+
mapping.append({"old": old_item, "new": new_item})
|
91 |
+
|
92 |
+
return mapping
|
93 |
+
|
94 |
+
|
95 |
+
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
96 |
+
"""
|
97 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
98 |
+
"""
|
99 |
+
mapping = []
|
100 |
+
for old_item in old_list:
|
101 |
+
new_item = old_item
|
102 |
+
|
103 |
+
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
104 |
+
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
105 |
+
|
106 |
+
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
107 |
+
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
108 |
+
|
109 |
+
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
110 |
+
|
111 |
+
mapping.append({"old": old_item, "new": new_item})
|
112 |
+
|
113 |
+
return mapping
|
114 |
+
|
115 |
+
|
116 |
+
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
117 |
+
"""
|
118 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
119 |
+
"""
|
120 |
+
mapping = []
|
121 |
+
for old_item in old_list:
|
122 |
+
new_item = old_item
|
123 |
+
|
124 |
+
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
125 |
+
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
126 |
+
|
127 |
+
new_item = new_item.replace("q.weight", "query.weight")
|
128 |
+
new_item = new_item.replace("q.bias", "query.bias")
|
129 |
+
|
130 |
+
new_item = new_item.replace("k.weight", "key.weight")
|
131 |
+
new_item = new_item.replace("k.bias", "key.bias")
|
132 |
+
|
133 |
+
new_item = new_item.replace("v.weight", "value.weight")
|
134 |
+
new_item = new_item.replace("v.bias", "value.bias")
|
135 |
+
|
136 |
+
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
137 |
+
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
138 |
+
|
139 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
140 |
+
|
141 |
+
mapping.append({"old": old_item, "new": new_item})
|
142 |
+
|
143 |
+
return mapping
|
144 |
+
|
145 |
+
|
146 |
+
def assign_to_checkpoint(
|
147 |
+
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
148 |
+
):
|
149 |
+
"""
|
150 |
+
This does the final conversion step: take locally converted weights and apply a global renaming
|
151 |
+
to them. It splits attention layers, and takes into account additional replacements
|
152 |
+
that may arise.
|
153 |
+
|
154 |
+
Assigns the weights to the new checkpoint.
|
155 |
+
"""
|
156 |
+
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
157 |
+
|
158 |
+
# Splits the attention layers into three variables.
|
159 |
+
if attention_paths_to_split is not None:
|
160 |
+
for path, path_map in attention_paths_to_split.items():
|
161 |
+
old_tensor = old_checkpoint[path]
|
162 |
+
channels = old_tensor.shape[0] // 3
|
163 |
+
|
164 |
+
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
165 |
+
|
166 |
+
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
167 |
+
|
168 |
+
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
169 |
+
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
170 |
+
|
171 |
+
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
172 |
+
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
173 |
+
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
174 |
+
|
175 |
+
for path in paths:
|
176 |
+
new_path = path["new"]
|
177 |
+
|
178 |
+
# These have already been assigned
|
179 |
+
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
180 |
+
continue
|
181 |
+
|
182 |
+
# Global renaming happens here
|
183 |
+
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
184 |
+
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
185 |
+
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
186 |
+
|
187 |
+
if additional_replacements is not None:
|
188 |
+
for replacement in additional_replacements:
|
189 |
+
new_path = new_path.replace(replacement["old"], replacement["new"])
|
190 |
+
|
191 |
+
# proj_attn.weight has to be converted from conv 1D to linear
|
192 |
+
if "proj_attn.weight" in new_path:
|
193 |
+
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
194 |
+
else:
|
195 |
+
checkpoint[new_path] = old_checkpoint[path["old"]]
|
196 |
+
|
197 |
+
|
198 |
+
def conv_attn_to_linear(checkpoint):
|
199 |
+
keys = list(checkpoint.keys())
|
200 |
+
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
201 |
+
for key in keys:
|
202 |
+
if ".".join(key.split(".")[-2:]) in attn_keys:
|
203 |
+
if checkpoint[key].ndim > 2:
|
204 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
205 |
+
elif "proj_attn.weight" in key:
|
206 |
+
if checkpoint[key].ndim > 2:
|
207 |
+
checkpoint[key] = checkpoint[key][:, :, 0]
|
208 |
+
|
209 |
+
|
210 |
+
def linear_transformer_to_conv(checkpoint):
|
211 |
+
keys = list(checkpoint.keys())
|
212 |
+
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
213 |
+
for key in keys:
|
214 |
+
if ".".join(key.split(".")[-2:]) in tf_keys:
|
215 |
+
if checkpoint[key].ndim == 2:
|
216 |
+
checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
|
217 |
+
|
218 |
+
|
219 |
+
def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
220 |
+
"""
|
221 |
+
Takes a state dict and a config, and returns a converted checkpoint.
|
222 |
+
"""
|
223 |
+
|
224 |
+
# extract state_dict for UNet
|
225 |
+
unet_state_dict = {}
|
226 |
+
unet_key = "model.diffusion_model."
|
227 |
+
keys = list(checkpoint.keys())
|
228 |
+
for key in keys:
|
229 |
+
if key.startswith(unet_key):
|
230 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
231 |
+
|
232 |
+
new_checkpoint = {}
|
233 |
+
|
234 |
+
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
235 |
+
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
236 |
+
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
237 |
+
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
238 |
+
|
239 |
+
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
240 |
+
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
241 |
+
|
242 |
+
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
243 |
+
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
244 |
+
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
245 |
+
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
246 |
+
|
247 |
+
# Retrieves the keys for the input blocks only
|
248 |
+
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
249 |
+
input_blocks = {
|
250 |
+
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
|
251 |
+
for layer_id in range(num_input_blocks)
|
252 |
+
}
|
253 |
+
|
254 |
+
# Retrieves the keys for the middle blocks only
|
255 |
+
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
256 |
+
middle_blocks = {
|
257 |
+
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
|
258 |
+
for layer_id in range(num_middle_blocks)
|
259 |
+
}
|
260 |
+
|
261 |
+
# Retrieves the keys for the output blocks only
|
262 |
+
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
263 |
+
output_blocks = {
|
264 |
+
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
|
265 |
+
for layer_id in range(num_output_blocks)
|
266 |
+
}
|
267 |
+
|
268 |
+
for i in range(1, num_input_blocks):
|
269 |
+
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
270 |
+
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
271 |
+
|
272 |
+
resnets = [
|
273 |
+
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
|
274 |
+
]
|
275 |
+
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
276 |
+
|
277 |
+
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
278 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
279 |
+
f"input_blocks.{i}.0.op.weight"
|
280 |
+
)
|
281 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
|
282 |
+
f"input_blocks.{i}.0.op.bias"
|
283 |
+
)
|
284 |
+
|
285 |
+
paths = renew_resnet_paths(resnets)
|
286 |
+
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
287 |
+
assign_to_checkpoint(
|
288 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
289 |
+
)
|
290 |
+
|
291 |
+
if len(attentions):
|
292 |
+
paths = renew_attention_paths(attentions)
|
293 |
+
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
294 |
+
assign_to_checkpoint(
|
295 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
296 |
+
)
|
297 |
+
|
298 |
+
resnet_0 = middle_blocks[0]
|
299 |
+
attentions = middle_blocks[1]
|
300 |
+
resnet_1 = middle_blocks[2]
|
301 |
+
|
302 |
+
resnet_0_paths = renew_resnet_paths(resnet_0)
|
303 |
+
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
304 |
+
|
305 |
+
resnet_1_paths = renew_resnet_paths(resnet_1)
|
306 |
+
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
307 |
+
|
308 |
+
attentions_paths = renew_attention_paths(attentions)
|
309 |
+
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
310 |
+
assign_to_checkpoint(
|
311 |
+
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
312 |
+
)
|
313 |
+
|
314 |
+
for i in range(num_output_blocks):
|
315 |
+
block_id = i // (config["layers_per_block"] + 1)
|
316 |
+
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
317 |
+
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
318 |
+
output_block_list = {}
|
319 |
+
|
320 |
+
for layer in output_block_layers:
|
321 |
+
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
322 |
+
if layer_id in output_block_list:
|
323 |
+
output_block_list[layer_id].append(layer_name)
|
324 |
+
else:
|
325 |
+
output_block_list[layer_id] = [layer_name]
|
326 |
+
|
327 |
+
if len(output_block_list) > 1:
|
328 |
+
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
329 |
+
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
330 |
+
|
331 |
+
resnet_0_paths = renew_resnet_paths(resnets)
|
332 |
+
paths = renew_resnet_paths(resnets)
|
333 |
+
|
334 |
+
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
335 |
+
assign_to_checkpoint(
|
336 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
337 |
+
)
|
338 |
+
|
339 |
+
# オリジナル:
|
340 |
+
# if ["conv.weight", "conv.bias"] in output_block_list.values():
|
341 |
+
# index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
|
342 |
+
|
343 |
+
# biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが
|
344 |
+
for l in output_block_list.values():
|
345 |
+
l.sort()
|
346 |
+
|
347 |
+
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
348 |
+
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
349 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
350 |
+
f"output_blocks.{i}.{index}.conv.bias"
|
351 |
+
]
|
352 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
353 |
+
f"output_blocks.{i}.{index}.conv.weight"
|
354 |
+
]
|
355 |
+
|
356 |
+
# Clear attentions as they have been attributed above.
|
357 |
+
if len(attentions) == 2:
|
358 |
+
attentions = []
|
359 |
+
|
360 |
+
if len(attentions):
|
361 |
+
paths = renew_attention_paths(attentions)
|
362 |
+
meta_path = {
|
363 |
+
"old": f"output_blocks.{i}.1",
|
364 |
+
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
365 |
+
}
|
366 |
+
assign_to_checkpoint(
|
367 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
368 |
+
)
|
369 |
+
else:
|
370 |
+
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
371 |
+
for path in resnet_0_paths:
|
372 |
+
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
373 |
+
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
374 |
+
|
375 |
+
new_checkpoint[new_path] = unet_state_dict[old_path]
|
376 |
+
|
377 |
+
# SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する
|
378 |
+
if v2:
|
379 |
+
linear_transformer_to_conv(new_checkpoint)
|
380 |
+
|
381 |
+
return new_checkpoint
|
382 |
+
|
383 |
+
|
384 |
+
def convert_ldm_vae_checkpoint(checkpoint, config):
|
385 |
+
# extract state dict for VAE
|
386 |
+
vae_state_dict = {}
|
387 |
+
vae_key = "first_stage_model."
|
388 |
+
keys = list(checkpoint.keys())
|
389 |
+
for key in keys:
|
390 |
+
if key.startswith(vae_key):
|
391 |
+
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
392 |
+
# if len(vae_state_dict) == 0:
|
393 |
+
# # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict
|
394 |
+
# vae_state_dict = checkpoint
|
395 |
+
|
396 |
+
new_checkpoint = {}
|
397 |
+
|
398 |
+
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
399 |
+
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
400 |
+
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
401 |
+
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
402 |
+
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
403 |
+
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
404 |
+
|
405 |
+
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
406 |
+
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
407 |
+
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
408 |
+
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
409 |
+
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
410 |
+
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
411 |
+
|
412 |
+
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
413 |
+
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
414 |
+
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
415 |
+
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
416 |
+
|
417 |
+
# Retrieves the keys for the encoder down blocks only
|
418 |
+
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
419 |
+
down_blocks = {
|
420 |
+
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
421 |
+
}
|
422 |
+
|
423 |
+
# Retrieves the keys for the decoder up blocks only
|
424 |
+
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
425 |
+
up_blocks = {
|
426 |
+
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
|
427 |
+
}
|
428 |
+
|
429 |
+
for i in range(num_down_blocks):
|
430 |
+
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
431 |
+
|
432 |
+
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
433 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
434 |
+
f"encoder.down.{i}.downsample.conv.weight"
|
435 |
+
)
|
436 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
437 |
+
f"encoder.down.{i}.downsample.conv.bias"
|
438 |
+
)
|
439 |
+
|
440 |
+
paths = renew_vae_resnet_paths(resnets)
|
441 |
+
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
442 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
443 |
+
|
444 |
+
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
445 |
+
num_mid_res_blocks = 2
|
446 |
+
for i in range(1, num_mid_res_blocks + 1):
|
447 |
+
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
448 |
+
|
449 |
+
paths = renew_vae_resnet_paths(resnets)
|
450 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
451 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
452 |
+
|
453 |
+
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
454 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
455 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
456 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
457 |
+
conv_attn_to_linear(new_checkpoint)
|
458 |
+
|
459 |
+
for i in range(num_up_blocks):
|
460 |
+
block_id = num_up_blocks - 1 - i
|
461 |
+
resnets = [
|
462 |
+
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
463 |
+
]
|
464 |
+
|
465 |
+
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
466 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
467 |
+
f"decoder.up.{block_id}.upsample.conv.weight"
|
468 |
+
]
|
469 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
470 |
+
f"decoder.up.{block_id}.upsample.conv.bias"
|
471 |
+
]
|
472 |
+
|
473 |
+
paths = renew_vae_resnet_paths(resnets)
|
474 |
+
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
475 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
476 |
+
|
477 |
+
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
478 |
+
num_mid_res_blocks = 2
|
479 |
+
for i in range(1, num_mid_res_blocks + 1):
|
480 |
+
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
481 |
+
|
482 |
+
paths = renew_vae_resnet_paths(resnets)
|
483 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
484 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
485 |
+
|
486 |
+
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
487 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
488 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
489 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
490 |
+
conv_attn_to_linear(new_checkpoint)
|
491 |
+
return new_checkpoint
|
492 |
+
|
493 |
+
|
494 |
+
def create_unet_diffusers_config(v2):
|
495 |
+
"""
|
496 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
497 |
+
"""
|
498 |
+
# unet_params = original_config.model.params.unet_config.params
|
499 |
+
|
500 |
+
block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
|
501 |
+
|
502 |
+
down_block_types = []
|
503 |
+
resolution = 1
|
504 |
+
for i in range(len(block_out_channels)):
|
505 |
+
block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
|
506 |
+
down_block_types.append(block_type)
|
507 |
+
if i != len(block_out_channels) - 1:
|
508 |
+
resolution *= 2
|
509 |
+
|
510 |
+
up_block_types = []
|
511 |
+
for i in range(len(block_out_channels)):
|
512 |
+
block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
|
513 |
+
up_block_types.append(block_type)
|
514 |
+
resolution //= 2
|
515 |
+
|
516 |
+
config = dict(
|
517 |
+
sample_size=UNET_PARAMS_IMAGE_SIZE,
|
518 |
+
in_channels=UNET_PARAMS_IN_CHANNELS,
|
519 |
+
out_channels=UNET_PARAMS_OUT_CHANNELS,
|
520 |
+
down_block_types=tuple(down_block_types),
|
521 |
+
up_block_types=tuple(up_block_types),
|
522 |
+
block_out_channels=tuple(block_out_channels),
|
523 |
+
layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
|
524 |
+
cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
|
525 |
+
attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
|
526 |
+
)
|
527 |
+
|
528 |
+
return config
|
529 |
+
|
530 |
+
|
531 |
+
def create_vae_diffusers_config():
|
532 |
+
"""
|
533 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
534 |
+
"""
|
535 |
+
# vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
536 |
+
# _ = original_config.model.params.first_stage_config.params.embed_dim
|
537 |
+
block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
|
538 |
+
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
539 |
+
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
540 |
+
|
541 |
+
config = dict(
|
542 |
+
sample_size=VAE_PARAMS_RESOLUTION,
|
543 |
+
in_channels=VAE_PARAMS_IN_CHANNELS,
|
544 |
+
out_channels=VAE_PARAMS_OUT_CH,
|
545 |
+
down_block_types=tuple(down_block_types),
|
546 |
+
up_block_types=tuple(up_block_types),
|
547 |
+
block_out_channels=tuple(block_out_channels),
|
548 |
+
latent_channels=VAE_PARAMS_Z_CHANNELS,
|
549 |
+
layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
|
550 |
+
)
|
551 |
+
return config
|
552 |
+
|
553 |
+
|
554 |
+
def convert_ldm_clip_checkpoint_v1(checkpoint):
|
555 |
+
keys = list(checkpoint.keys())
|
556 |
+
text_model_dict = {}
|
557 |
+
for key in keys:
|
558 |
+
if key.startswith("cond_stage_model.transformer"):
|
559 |
+
text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
|
560 |
+
return text_model_dict
|
561 |
+
|
562 |
+
|
563 |
+
def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
564 |
+
# 嫌になるくらい違うぞ!
|
565 |
+
def convert_key(key):
|
566 |
+
if not key.startswith("cond_stage_model"):
|
567 |
+
return None
|
568 |
+
|
569 |
+
# common conversion
|
570 |
+
key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
|
571 |
+
key = key.replace("cond_stage_model.model.", "text_model.")
|
572 |
+
|
573 |
+
if "resblocks" in key:
|
574 |
+
# resblocks conversion
|
575 |
+
key = key.replace(".resblocks.", ".layers.")
|
576 |
+
if ".ln_" in key:
|
577 |
+
key = key.replace(".ln_", ".layer_norm")
|
578 |
+
elif ".mlp." in key:
|
579 |
+
key = key.replace(".c_fc.", ".fc1.")
|
580 |
+
key = key.replace(".c_proj.", ".fc2.")
|
581 |
+
elif '.attn.out_proj' in key:
|
582 |
+
key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
|
583 |
+
elif '.attn.in_proj' in key:
|
584 |
+
key = None # 特殊なので後で処理する
|
585 |
+
else:
|
586 |
+
raise ValueError(f"unexpected key in SD: {key}")
|
587 |
+
elif '.positional_embedding' in key:
|
588 |
+
key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
|
589 |
+
elif '.text_projection' in key:
|
590 |
+
key = None # 使われない???
|
591 |
+
elif '.logit_scale' in key:
|
592 |
+
key = None # 使われない???
|
593 |
+
elif '.token_embedding' in key:
|
594 |
+
key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
|
595 |
+
elif '.ln_final' in key:
|
596 |
+
key = key.replace(".ln_final", ".final_layer_norm")
|
597 |
+
return key
|
598 |
+
|
599 |
+
keys = list(checkpoint.keys())
|
600 |
+
new_sd = {}
|
601 |
+
for key in keys:
|
602 |
+
# remove resblocks 23
|
603 |
+
if '.resblocks.23.' in key:
|
604 |
+
continue
|
605 |
+
new_key = convert_key(key)
|
606 |
+
if new_key is None:
|
607 |
+
continue
|
608 |
+
new_sd[new_key] = checkpoint[key]
|
609 |
+
|
610 |
+
# attnの変換
|
611 |
+
for key in keys:
|
612 |
+
if '.resblocks.23.' in key:
|
613 |
+
continue
|
614 |
+
if '.resblocks' in key and '.attn.in_proj_' in key:
|
615 |
+
# 三つに分割
|
616 |
+
values = torch.chunk(checkpoint[key], 3)
|
617 |
+
|
618 |
+
key_suffix = ".weight" if "weight" in key else ".bias"
|
619 |
+
key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
|
620 |
+
key_pfx = key_pfx.replace("_weight", "")
|
621 |
+
key_pfx = key_pfx.replace("_bias", "")
|
622 |
+
key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
|
623 |
+
new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
|
624 |
+
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
|
625 |
+
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
|
626 |
+
|
627 |
+
# rename or add position_ids
|
628 |
+
ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
|
629 |
+
if ANOTHER_POSITION_IDS_KEY in new_sd:
|
630 |
+
# waifu diffusion v1.4
|
631 |
+
position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
|
632 |
+
del new_sd[ANOTHER_POSITION_IDS_KEY]
|
633 |
+
else:
|
634 |
+
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
|
635 |
+
|
636 |
+
new_sd["text_model.embeddings.position_ids"] = position_ids
|
637 |
+
return new_sd
|
638 |
+
|
639 |
+
# endregion
|
640 |
+
|
641 |
+
|
642 |
+
# region Diffusers->StableDiffusion の変換コード
|
643 |
+
# convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0)
|
644 |
+
|
645 |
+
def conv_transformer_to_linear(checkpoint):
|
646 |
+
keys = list(checkpoint.keys())
|
647 |
+
tf_keys = ["proj_in.weight", "proj_out.weight"]
|
648 |
+
for key in keys:
|
649 |
+
if ".".join(key.split(".")[-2:]) in tf_keys:
|
650 |
+
if checkpoint[key].ndim > 2:
|
651 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
652 |
+
|
653 |
+
|
654 |
+
def convert_unet_state_dict_to_sd(v2, unet_state_dict):
|
655 |
+
unet_conversion_map = [
|
656 |
+
# (stable-diffusion, HF Diffusers)
|
657 |
+
("time_embed.0.weight", "time_embedding.linear_1.weight"),
|
658 |
+
("time_embed.0.bias", "time_embedding.linear_1.bias"),
|
659 |
+
("time_embed.2.weight", "time_embedding.linear_2.weight"),
|
660 |
+
("time_embed.2.bias", "time_embedding.linear_2.bias"),
|
661 |
+
("input_blocks.0.0.weight", "conv_in.weight"),
|
662 |
+
("input_blocks.0.0.bias", "conv_in.bias"),
|
663 |
+
("out.0.weight", "conv_norm_out.weight"),
|
664 |
+
("out.0.bias", "conv_norm_out.bias"),
|
665 |
+
("out.2.weight", "conv_out.weight"),
|
666 |
+
("out.2.bias", "conv_out.bias"),
|
667 |
+
]
|
668 |
+
|
669 |
+
unet_conversion_map_resnet = [
|
670 |
+
# (stable-diffusion, HF Diffusers)
|
671 |
+
("in_layers.0", "norm1"),
|
672 |
+
("in_layers.2", "conv1"),
|
673 |
+
("out_layers.0", "norm2"),
|
674 |
+
("out_layers.3", "conv2"),
|
675 |
+
("emb_layers.1", "time_emb_proj"),
|
676 |
+
("skip_connection", "conv_shortcut"),
|
677 |
+
]
|
678 |
+
|
679 |
+
unet_conversion_map_layer = []
|
680 |
+
for i in range(4):
|
681 |
+
# loop over downblocks/upblocks
|
682 |
+
|
683 |
+
for j in range(2):
|
684 |
+
# loop over resnets/attentions for downblocks
|
685 |
+
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
686 |
+
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
687 |
+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
688 |
+
|
689 |
+
if i < 3:
|
690 |
+
# no attention layers in down_blocks.3
|
691 |
+
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
692 |
+
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
693 |
+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
694 |
+
|
695 |
+
for j in range(3):
|
696 |
+
# loop over resnets/attentions for upblocks
|
697 |
+
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
698 |
+
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
699 |
+
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
700 |
+
|
701 |
+
if i > 0:
|
702 |
+
# no attention layers in up_blocks.0
|
703 |
+
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
704 |
+
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
705 |
+
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
706 |
+
|
707 |
+
if i < 3:
|
708 |
+
# no downsample in down_blocks.3
|
709 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
710 |
+
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
711 |
+
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
712 |
+
|
713 |
+
# no upsample in up_blocks.3
|
714 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
715 |
+
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
|
716 |
+
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
717 |
+
|
718 |
+
hf_mid_atn_prefix = "mid_block.attentions.0."
|
719 |
+
sd_mid_atn_prefix = "middle_block.1."
|
720 |
+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
721 |
+
|
722 |
+
for j in range(2):
|
723 |
+
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
724 |
+
sd_mid_res_prefix = f"middle_block.{2*j}."
|
725 |
+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
726 |
+
|
727 |
+
# buyer beware: this is a *brittle* function,
|
728 |
+
# and correct output requires that all of these pieces interact in
|
729 |
+
# the exact order in which I have arranged them.
|
730 |
+
mapping = {k: k for k in unet_state_dict.keys()}
|
731 |
+
for sd_name, hf_name in unet_conversion_map:
|
732 |
+
mapping[hf_name] = sd_name
|
733 |
+
for k, v in mapping.items():
|
734 |
+
if "resnets" in k:
|
735 |
+
for sd_part, hf_part in unet_conversion_map_resnet:
|
736 |
+
v = v.replace(hf_part, sd_part)
|
737 |
+
mapping[k] = v
|
738 |
+
for k, v in mapping.items():
|
739 |
+
for sd_part, hf_part in unet_conversion_map_layer:
|
740 |
+
v = v.replace(hf_part, sd_part)
|
741 |
+
mapping[k] = v
|
742 |
+
new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
|
743 |
+
|
744 |
+
if v2:
|
745 |
+
conv_transformer_to_linear(new_state_dict)
|
746 |
+
|
747 |
+
return new_state_dict
|
748 |
+
|
749 |
+
|
750 |
+
# ================#
|
751 |
+
# VAE Conversion #
|
752 |
+
# ================#
|
753 |
+
|
754 |
+
def reshape_weight_for_sd(w):
|
755 |
+
# convert HF linear weights to SD conv2d weights
|
756 |
+
return w.reshape(*w.shape, 1, 1)
|
757 |
+
|
758 |
+
|
759 |
+
def convert_vae_state_dict(vae_state_dict):
|
760 |
+
vae_conversion_map = [
|
761 |
+
# (stable-diffusion, HF Diffusers)
|
762 |
+
("nin_shortcut", "conv_shortcut"),
|
763 |
+
("norm_out", "conv_norm_out"),
|
764 |
+
("mid.attn_1.", "mid_block.attentions.0."),
|
765 |
+
]
|
766 |
+
|
767 |
+
for i in range(4):
|
768 |
+
# down_blocks have two resnets
|
769 |
+
for j in range(2):
|
770 |
+
hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
|
771 |
+
sd_down_prefix = f"encoder.down.{i}.block.{j}."
|
772 |
+
vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
|
773 |
+
|
774 |
+
if i < 3:
|
775 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
|
776 |
+
sd_downsample_prefix = f"down.{i}.downsample."
|
777 |
+
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
778 |
+
|
779 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
780 |
+
sd_upsample_prefix = f"up.{3-i}.upsample."
|
781 |
+
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
782 |
+
|
783 |
+
# up_blocks have three resnets
|
784 |
+
# also, up blocks in hf are numbered in reverse from sd
|
785 |
+
for j in range(3):
|
786 |
+
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
787 |
+
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
|
788 |
+
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
789 |
+
|
790 |
+
# this part accounts for mid blocks in both the encoder and the decoder
|
791 |
+
for i in range(2):
|
792 |
+
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
793 |
+
sd_mid_res_prefix = f"mid.block_{i+1}."
|
794 |
+
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
795 |
+
|
796 |
+
vae_conversion_map_attn = [
|
797 |
+
# (stable-diffusion, HF Diffusers)
|
798 |
+
("norm.", "group_norm."),
|
799 |
+
("q.", "query."),
|
800 |
+
("k.", "key."),
|
801 |
+
("v.", "value."),
|
802 |
+
("proj_out.", "proj_attn."),
|
803 |
+
]
|
804 |
+
|
805 |
+
mapping = {k: k for k in vae_state_dict.keys()}
|
806 |
+
for k, v in mapping.items():
|
807 |
+
for sd_part, hf_part in vae_conversion_map:
|
808 |
+
v = v.replace(hf_part, sd_part)
|
809 |
+
mapping[k] = v
|
810 |
+
for k, v in mapping.items():
|
811 |
+
if "attentions" in k:
|
812 |
+
for sd_part, hf_part in vae_conversion_map_attn:
|
813 |
+
v = v.replace(hf_part, sd_part)
|
814 |
+
mapping[k] = v
|
815 |
+
new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
|
816 |
+
weights_to_convert = ["q", "k", "v", "proj_out"]
|
817 |
+
for k, v in new_state_dict.items():
|
818 |
+
for weight_name in weights_to_convert:
|
819 |
+
if f"mid.attn_1.{weight_name}.weight" in k:
|
820 |
+
# print(f"Reshaping {k} for SD format")
|
821 |
+
new_state_dict[k] = reshape_weight_for_sd(v)
|
822 |
+
|
823 |
+
return new_state_dict
|
824 |
+
|
825 |
+
|
826 |
+
# endregion
|
827 |
+
|
828 |
+
# region 自作のモデル読み書きなど
|
829 |
+
|
830 |
+
def is_safetensors(path):
|
831 |
+
return os.path.splitext(path)[1].lower() == '.safetensors'
|
832 |
+
|
833 |
+
|
834 |
+
def load_checkpoint_with_text_encoder_conversion(ckpt_path):
|
835 |
+
# text encoderの格納形式が違うモデルに対応する ('text_model'がない)
|
836 |
+
TEXT_ENCODER_KEY_REPLACEMENTS = [
|
837 |
+
('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
|
838 |
+
('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
|
839 |
+
('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
|
840 |
+
]
|
841 |
+
|
842 |
+
if is_safetensors(ckpt_path):
|
843 |
+
checkpoint = None
|
844 |
+
state_dict = load_file(ckpt_path, "cpu")
|
845 |
+
else:
|
846 |
+
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
847 |
+
if "state_dict" in checkpoint:
|
848 |
+
state_dict = checkpoint["state_dict"]
|
849 |
+
else:
|
850 |
+
state_dict = checkpoint
|
851 |
+
checkpoint = None
|
852 |
+
|
853 |
+
key_reps = []
|
854 |
+
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
|
855 |
+
for key in state_dict.keys():
|
856 |
+
if key.startswith(rep_from):
|
857 |
+
new_key = rep_to + key[len(rep_from):]
|
858 |
+
key_reps.append((key, new_key))
|
859 |
+
|
860 |
+
for key, new_key in key_reps:
|
861 |
+
state_dict[new_key] = state_dict[key]
|
862 |
+
del state_dict[key]
|
863 |
+
|
864 |
+
return checkpoint, state_dict
|
865 |
+
|
866 |
+
|
867 |
+
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
|
868 |
+
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
|
869 |
+
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
870 |
+
if dtype is not None:
|
871 |
+
for k, v in state_dict.items():
|
872 |
+
if type(v) is torch.Tensor:
|
873 |
+
state_dict[k] = v.to(dtype)
|
874 |
+
|
875 |
+
# Convert the UNet2DConditionModel model.
|
876 |
+
unet_config = create_unet_diffusers_config(v2)
|
877 |
+
converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
|
878 |
+
|
879 |
+
unet = UNet2DConditionModel(**unet_config)
|
880 |
+
info = unet.load_state_dict(converted_unet_checkpoint)
|
881 |
+
print("loading u-net:", info)
|
882 |
+
|
883 |
+
# Convert the VAE model.
|
884 |
+
vae_config = create_vae_diffusers_config()
|
885 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
|
886 |
+
|
887 |
+
vae = AutoencoderKL(**vae_config)
|
888 |
+
info = vae.load_state_dict(converted_vae_checkpoint)
|
889 |
+
print("loading vae:", info)
|
890 |
+
|
891 |
+
# convert text_model
|
892 |
+
if v2:
|
893 |
+
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
|
894 |
+
cfg = CLIPTextConfig(
|
895 |
+
vocab_size=49408,
|
896 |
+
hidden_size=1024,
|
897 |
+
intermediate_size=4096,
|
898 |
+
num_hidden_layers=23,
|
899 |
+
num_attention_heads=16,
|
900 |
+
max_position_embeddings=77,
|
901 |
+
hidden_act="gelu",
|
902 |
+
layer_norm_eps=1e-05,
|
903 |
+
dropout=0.0,
|
904 |
+
attention_dropout=0.0,
|
905 |
+
initializer_range=0.02,
|
906 |
+
initializer_factor=1.0,
|
907 |
+
pad_token_id=1,
|
908 |
+
bos_token_id=0,
|
909 |
+
eos_token_id=2,
|
910 |
+
model_type="clip_text_model",
|
911 |
+
projection_dim=512,
|
912 |
+
torch_dtype="float32",
|
913 |
+
transformers_version="4.25.0.dev0",
|
914 |
+
)
|
915 |
+
text_model = CLIPTextModel._from_config(cfg)
|
916 |
+
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
917 |
+
else:
|
918 |
+
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
|
919 |
+
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
920 |
+
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
921 |
+
print("loading text encoder:", info)
|
922 |
+
|
923 |
+
return text_model, vae, unet
|
924 |
+
|
925 |
+
|
926 |
+
def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
|
927 |
+
def convert_key(key):
|
928 |
+
# position_idsの除去
|
929 |
+
if ".position_ids" in key:
|
930 |
+
return None
|
931 |
+
|
932 |
+
# common
|
933 |
+
key = key.replace("text_model.encoder.", "transformer.")
|
934 |
+
key = key.replace("text_model.", "")
|
935 |
+
if "layers" in key:
|
936 |
+
# resblocks conversion
|
937 |
+
key = key.replace(".layers.", ".resblocks.")
|
938 |
+
if ".layer_norm" in key:
|
939 |
+
key = key.replace(".layer_norm", ".ln_")
|
940 |
+
elif ".mlp." in key:
|
941 |
+
key = key.replace(".fc1.", ".c_fc.")
|
942 |
+
key = key.replace(".fc2.", ".c_proj.")
|
943 |
+
elif '.self_attn.out_proj' in key:
|
944 |
+
key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
|
945 |
+
elif '.self_attn.' in key:
|
946 |
+
key = None # 特殊なので後で処理する
|
947 |
+
else:
|
948 |
+
raise ValueError(f"unexpected key in DiffUsers model: {key}")
|
949 |
+
elif '.position_embedding' in key:
|
950 |
+
key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
|
951 |
+
elif '.token_embedding' in key:
|
952 |
+
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
|
953 |
+
elif 'final_layer_norm' in key:
|
954 |
+
key = key.replace("final_layer_norm", "ln_final")
|
955 |
+
return key
|
956 |
+
|
957 |
+
keys = list(checkpoint.keys())
|
958 |
+
new_sd = {}
|
959 |
+
for key in keys:
|
960 |
+
new_key = convert_key(key)
|
961 |
+
if new_key is None:
|
962 |
+
continue
|
963 |
+
new_sd[new_key] = checkpoint[key]
|
964 |
+
|
965 |
+
# attnの変換
|
966 |
+
for key in keys:
|
967 |
+
if 'layers' in key and 'q_proj' in key:
|
968 |
+
# 三つを結合
|
969 |
+
key_q = key
|
970 |
+
key_k = key.replace("q_proj", "k_proj")
|
971 |
+
key_v = key.replace("q_proj", "v_proj")
|
972 |
+
|
973 |
+
value_q = checkpoint[key_q]
|
974 |
+
value_k = checkpoint[key_k]
|
975 |
+
value_v = checkpoint[key_v]
|
976 |
+
value = torch.cat([value_q, value_k, value_v])
|
977 |
+
|
978 |
+
new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
|
979 |
+
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
|
980 |
+
new_sd[new_key] = value
|
981 |
+
|
982 |
+
# 最後の層などを捏造するか
|
983 |
+
if make_dummy_weights:
|
984 |
+
print("make dummy weights for resblock.23, text_projection and logit scale.")
|
985 |
+
keys = list(new_sd.keys())
|
986 |
+
for key in keys:
|
987 |
+
if key.startswith("transformer.resblocks.22."):
|
988 |
+
new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる
|
989 |
+
|
990 |
+
# Diffusersに含まれない重みを作っておく
|
991 |
+
new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
|
992 |
+
new_sd['logit_scale'] = torch.tensor(1)
|
993 |
+
|
994 |
+
return new_sd
|
995 |
+
|
996 |
+
|
997 |
+
def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
|
998 |
+
if ckpt_path is not None:
|
999 |
+
# epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む
|
1000 |
+
checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
1001 |
+
if checkpoint is None: # safetensors または state_dictのckpt
|
1002 |
+
checkpoint = {}
|
1003 |
+
strict = False
|
1004 |
+
else:
|
1005 |
+
strict = True
|
1006 |
+
if "state_dict" in state_dict:
|
1007 |
+
del state_dict["state_dict"]
|
1008 |
+
else:
|
1009 |
+
# 新しく作る
|
1010 |
+
assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint"
|
1011 |
+
checkpoint = {}
|
1012 |
+
state_dict = {}
|
1013 |
+
strict = False
|
1014 |
+
|
1015 |
+
def update_sd(prefix, sd):
|
1016 |
+
for k, v in sd.items():
|
1017 |
+
key = prefix + k
|
1018 |
+
assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
|
1019 |
+
if save_dtype is not None:
|
1020 |
+
v = v.detach().clone().to("cpu").to(save_dtype)
|
1021 |
+
state_dict[key] = v
|
1022 |
+
|
1023 |
+
# Convert the UNet model
|
1024 |
+
unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
|
1025 |
+
update_sd("model.diffusion_model.", unet_state_dict)
|
1026 |
+
|
1027 |
+
# Convert the text encoder model
|
1028 |
+
if v2:
|
1029 |
+
make_dummy = ckpt_path is None # 参照元のcheckpointがない場合は最後の層を前の層から複製��て作るなどダミーの重みを入れる
|
1030 |
+
text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
|
1031 |
+
update_sd("cond_stage_model.model.", text_enc_dict)
|
1032 |
+
else:
|
1033 |
+
text_enc_dict = text_encoder.state_dict()
|
1034 |
+
update_sd("cond_stage_model.transformer.", text_enc_dict)
|
1035 |
+
|
1036 |
+
# Convert the VAE
|
1037 |
+
if vae is not None:
|
1038 |
+
vae_dict = convert_vae_state_dict(vae.state_dict())
|
1039 |
+
update_sd("first_stage_model.", vae_dict)
|
1040 |
+
|
1041 |
+
# Put together new checkpoint
|
1042 |
+
key_count = len(state_dict.keys())
|
1043 |
+
new_ckpt = {'state_dict': state_dict}
|
1044 |
+
|
1045 |
+
if 'epoch' in checkpoint:
|
1046 |
+
epochs += checkpoint['epoch']
|
1047 |
+
if 'global_step' in checkpoint:
|
1048 |
+
steps += checkpoint['global_step']
|
1049 |
+
|
1050 |
+
new_ckpt['epoch'] = epochs
|
1051 |
+
new_ckpt['global_step'] = steps
|
1052 |
+
|
1053 |
+
if is_safetensors(output_file):
|
1054 |
+
# TODO Tensor以外のdictの値を削除したほうがいいか
|
1055 |
+
save_file(state_dict, output_file)
|
1056 |
+
else:
|
1057 |
+
torch.save(new_ckpt, output_file)
|
1058 |
+
|
1059 |
+
return key_count
|
1060 |
+
|
1061 |
+
|
1062 |
+
def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
|
1063 |
+
if pretrained_model_name_or_path is None:
|
1064 |
+
# load default settings for v1/v2
|
1065 |
+
if v2:
|
1066 |
+
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
|
1067 |
+
else:
|
1068 |
+
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
|
1069 |
+
|
1070 |
+
scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
|
1071 |
+
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
|
1072 |
+
if vae is None:
|
1073 |
+
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
|
1074 |
+
|
1075 |
+
pipeline = StableDiffusionPipeline(
|
1076 |
+
unet=unet,
|
1077 |
+
text_encoder=text_encoder,
|
1078 |
+
vae=vae,
|
1079 |
+
scheduler=scheduler,
|
1080 |
+
tokenizer=tokenizer,
|
1081 |
+
safety_checker=None,
|
1082 |
+
feature_extractor=None,
|
1083 |
+
requires_safety_checker=None,
|
1084 |
+
)
|
1085 |
+
pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
|
1086 |
+
|
1087 |
+
|
1088 |
+
VAE_PREFIX = "first_stage_model."
|
1089 |
+
|
1090 |
+
|
1091 |
+
def load_vae(vae_id, dtype):
|
1092 |
+
print(f"load VAE: {vae_id}")
|
1093 |
+
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
|
1094 |
+
# Diffusers local/remote
|
1095 |
+
try:
|
1096 |
+
vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
|
1097 |
+
except EnvironmentError as e:
|
1098 |
+
print(f"exception occurs in loading vae: {e}")
|
1099 |
+
print("retry with subfolder='vae'")
|
1100 |
+
vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
|
1101 |
+
return vae
|
1102 |
+
|
1103 |
+
# local
|
1104 |
+
vae_config = create_vae_diffusers_config()
|
1105 |
+
|
1106 |
+
if vae_id.endswith(".bin"):
|
1107 |
+
# SD 1.5 VAE on Huggingface
|
1108 |
+
converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
|
1109 |
+
else:
|
1110 |
+
# StableDiffusion
|
1111 |
+
vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
|
1112 |
+
else torch.load(vae_id, map_location="cpu"))
|
1113 |
+
vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
|
1114 |
+
|
1115 |
+
# vae only or full model
|
1116 |
+
full_model = False
|
1117 |
+
for vae_key in vae_sd:
|
1118 |
+
if vae_key.startswith(VAE_PREFIX):
|
1119 |
+
full_model = True
|
1120 |
+
break
|
1121 |
+
if not full_model:
|
1122 |
+
sd = {}
|
1123 |
+
for key, value in vae_sd.items():
|
1124 |
+
sd[VAE_PREFIX + key] = value
|
1125 |
+
vae_sd = sd
|
1126 |
+
del sd
|
1127 |
+
|
1128 |
+
# Convert the VAE model.
|
1129 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
|
1130 |
+
|
1131 |
+
vae = AutoencoderKL(**vae_config)
|
1132 |
+
vae.load_state_dict(converted_vae_checkpoint)
|
1133 |
+
return vae
|
1134 |
+
|
1135 |
+
# endregion
|
1136 |
+
|
1137 |
+
|
1138 |
+
def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
|
1139 |
+
max_width, max_height = max_reso
|
1140 |
+
max_area = (max_width // divisible) * (max_height // divisible)
|
1141 |
+
|
1142 |
+
resos = set()
|
1143 |
+
|
1144 |
+
size = int(math.sqrt(max_area)) * divisible
|
1145 |
+
resos.add((size, size))
|
1146 |
+
|
1147 |
+
size = min_size
|
1148 |
+
while size <= max_size:
|
1149 |
+
width = size
|
1150 |
+
height = min(max_size, (max_area // (width // divisible)) * divisible)
|
1151 |
+
resos.add((width, height))
|
1152 |
+
resos.add((height, width))
|
1153 |
+
|
1154 |
+
# # make additional resos
|
1155 |
+
# if width >= height and width - divisible >= min_size:
|
1156 |
+
# resos.add((width - divisible, height))
|
1157 |
+
# resos.add((height, width - divisible))
|
1158 |
+
# if height >= width and height - divisible >= min_size:
|
1159 |
+
# resos.add((width, height - divisible))
|
1160 |
+
# resos.add((height - divisible, width))
|
1161 |
+
|
1162 |
+
size += divisible
|
1163 |
+
|
1164 |
+
resos = list(resos)
|
1165 |
+
resos.sort()
|
1166 |
+
return resos
|
1167 |
+
|
1168 |
+
|
1169 |
+
if __name__ == '__main__':
|
1170 |
+
resos = make_bucket_resolutions((512, 768))
|
1171 |
+
print(len(resos))
|
1172 |
+
print(resos)
|
1173 |
+
aspect_ratios = [w / h for w, h in resos]
|
1174 |
+
print(aspect_ratios)
|
1175 |
+
|
1176 |
+
ars = set()
|
1177 |
+
for ar in aspect_ratios:
|
1178 |
+
if ar in ars:
|
1179 |
+
print("error! duplicate ar:", ar)
|
1180 |
+
ars.add(ar)
|
Lora/lib/qwerty.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from safetensors.torch import load_file
|
2 |
+
import sys
|
3 |
+
import torch
|
4 |
+
from pathlib import Path
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
def cal_cross_attn(to_q, to_k, to_v, rand_input):
|
9 |
+
hidden_dim, embed_dim = to_q.shape
|
10 |
+
attn_to_q = nn.Linear(hidden_dim, embed_dim, bias=False)
|
11 |
+
attn_to_k = nn.Linear(hidden_dim, embed_dim, bias=False)
|
12 |
+
attn_to_v = nn.Linear(hidden_dim, embed_dim, bias=False)
|
13 |
+
attn_to_q.load_state_dict({"weight": to_q})
|
14 |
+
attn_to_k.load_state_dict({"weight": to_k})
|
15 |
+
attn_to_v.load_state_dict({"weight": to_v})
|
16 |
+
|
17 |
+
return torch.einsum(
|
18 |
+
"ik, jk -> ik",
|
19 |
+
F.softmax(torch.einsum("ij, kj -> ik", attn_to_q(rand_input), attn_to_k(rand_input)), dim=-1),
|
20 |
+
attn_to_v(rand_input)
|
21 |
+
)
|
22 |
+
|
23 |
+
def model_hash(filename):
|
24 |
+
try:
|
25 |
+
with open(filename, "rb") as file:
|
26 |
+
import hashlib
|
27 |
+
m = hashlib.sha256()
|
28 |
+
|
29 |
+
file.seek(0x100000)
|
30 |
+
m.update(file.read(0x10000))
|
31 |
+
return m.hexdigest()[0:8]
|
32 |
+
except FileNotFoundError:
|
33 |
+
return 'NOFILE'
|
34 |
+
|
35 |
+
def load_model(path):
|
36 |
+
if path.suffix == ".safetensors":
|
37 |
+
return load_file(path, device="cpu")
|
38 |
+
else:
|
39 |
+
ckpt = torch.load(path, map_location="cpu")
|
40 |
+
return ckpt["state_dict"] if "state_dict" in ckpt else ckpt
|
41 |
+
|
42 |
+
def eval(model, n, input):
|
43 |
+
qk = f"model.diffusion_model.output_blocks.{n}.1.transformer_blocks.0.attn1.to_q.weight"
|
44 |
+
uk = f"model.diffusion_model.output_blocks.{n}.1.transformer_blocks.0.attn1.to_k.weight"
|
45 |
+
vk = f"model.diffusion_model.output_blocks.{n}.1.transformer_blocks.0.attn1.to_v.weight"
|
46 |
+
atoq, atok, atov = model[qk], model[uk], model[vk]
|
47 |
+
|
48 |
+
attn = cal_cross_attn(atoq, atok, atov, input)
|
49 |
+
return attn
|
50 |
+
|
51 |
+
def main():
|
52 |
+
file1 = Path(sys.argv[1])
|
53 |
+
files = sys.argv[2:]
|
54 |
+
|
55 |
+
seed = 114514
|
56 |
+
torch.manual_seed(seed)
|
57 |
+
print(f"seed: {seed}")
|
58 |
+
|
59 |
+
model_a = load_model(file1)
|
60 |
+
|
61 |
+
print()
|
62 |
+
print(f"base: {file1.name} [{model_hash(file1)}]")
|
63 |
+
print()
|
64 |
+
|
65 |
+
map_attn_a = {}
|
66 |
+
map_rand_input = {}
|
67 |
+
for n in range(3, 11):
|
68 |
+
hidden_dim, embed_dim = model_a[f"model.diffusion_model.output_blocks.{n}.1.transformer_blocks.0.attn1.to_q.weight"].shape
|
69 |
+
rand_input = torch.randn([embed_dim, hidden_dim])
|
70 |
+
|
71 |
+
map_attn_a[n] = eval(model_a, n, rand_input)
|
72 |
+
map_rand_input[n] = rand_input
|
73 |
+
|
74 |
+
del model_a
|
75 |
+
|
76 |
+
for file2 in files:
|
77 |
+
file2 = Path(file2)
|
78 |
+
model_b = load_model(file2)
|
79 |
+
|
80 |
+
sims = []
|
81 |
+
for n in range(3, 11):
|
82 |
+
attn_a = map_attn_a[n]
|
83 |
+
attn_b = eval(model_b, n, map_rand_input[n])
|
84 |
+
|
85 |
+
sim = torch.mean(torch.cosine_similarity(attn_a, attn_b))
|
86 |
+
sims.append(sim)
|
87 |
+
|
88 |
+
print(f"{file2} [{model_hash(file2)}] - {torch.mean(torch.stack(sims)) * 1e2:.2f}%")
|
89 |
+
|
90 |
+
if __name__ == "__main__":
|
91 |
+
main()
|
Lora/lib/train_util.py
ADDED
@@ -0,0 +1,1766 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# common functions for training
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import json
|
5 |
+
import shutil
|
6 |
+
import time
|
7 |
+
from typing import Dict, List, NamedTuple, Tuple
|
8 |
+
from accelerate import Accelerator
|
9 |
+
from torch.autograd.function import Function
|
10 |
+
import glob
|
11 |
+
import math
|
12 |
+
import os
|
13 |
+
import random
|
14 |
+
import hashlib
|
15 |
+
from io import BytesIO
|
16 |
+
|
17 |
+
from tqdm import tqdm
|
18 |
+
import torch
|
19 |
+
from torchvision import transforms
|
20 |
+
from transformers import CLIPTokenizer
|
21 |
+
import diffusers
|
22 |
+
from diffusers import DDPMScheduler, StableDiffusionPipeline
|
23 |
+
import albumentations as albu
|
24 |
+
import numpy as np
|
25 |
+
from PIL import Image
|
26 |
+
import cv2
|
27 |
+
from einops import rearrange
|
28 |
+
from torch import einsum
|
29 |
+
import safetensors.torch
|
30 |
+
|
31 |
+
import train_util
|
32 |
+
|
33 |
+
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
34 |
+
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
35 |
+
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ
|
36 |
+
|
37 |
+
# checkpointファイル名
|
38 |
+
EPOCH_STATE_NAME = "{}-{:06d}-state"
|
39 |
+
EPOCH_FILE_NAME = "{}-{:06d}"
|
40 |
+
EPOCH_DIFFUSERS_DIR_NAME = "{}-{:06d}"
|
41 |
+
LAST_STATE_NAME = "{}-state"
|
42 |
+
DEFAULT_EPOCH_NAME = "epoch"
|
43 |
+
DEFAULT_LAST_OUTPUT_NAME = "last"
|
44 |
+
|
45 |
+
# region dataset
|
46 |
+
|
47 |
+
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]
|
48 |
+
# , ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] # Linux?
|
49 |
+
|
50 |
+
|
51 |
+
class ImageInfo():
|
52 |
+
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
|
53 |
+
self.image_key: str = image_key
|
54 |
+
self.num_repeats: int = num_repeats
|
55 |
+
self.caption: str = caption
|
56 |
+
self.is_reg: bool = is_reg
|
57 |
+
self.absolute_path: str = absolute_path
|
58 |
+
self.image_size: Tuple[int, int] = None
|
59 |
+
self.resized_size: Tuple[int, int] = None
|
60 |
+
self.bucket_reso: Tuple[int, int] = None
|
61 |
+
self.latents: torch.Tensor = None
|
62 |
+
self.latents_flipped: torch.Tensor = None
|
63 |
+
self.latents_npz: str = None
|
64 |
+
self.latents_npz_flipped: str = None
|
65 |
+
|
66 |
+
|
67 |
+
class BucketManager():
|
68 |
+
def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
|
69 |
+
self.no_upscale = no_upscale
|
70 |
+
if max_reso is None:
|
71 |
+
self.max_reso = None
|
72 |
+
self.max_area = None
|
73 |
+
else:
|
74 |
+
self.max_reso = max_reso
|
75 |
+
self.max_area = max_reso[0] * max_reso[1]
|
76 |
+
self.min_size = min_size
|
77 |
+
self.max_size = max_size
|
78 |
+
self.reso_steps = reso_steps
|
79 |
+
|
80 |
+
self.resos = []
|
81 |
+
self.reso_to_id = {}
|
82 |
+
self.buckets = [] # 前処理時は (image_key, image)、学習時は image_key
|
83 |
+
|
84 |
+
def add_image(self, reso, image):
|
85 |
+
bucket_id = self.reso_to_id[reso]
|
86 |
+
self.buckets[bucket_id].append(image)
|
87 |
+
|
88 |
+
def shuffle(self):
|
89 |
+
for bucket in self.buckets:
|
90 |
+
random.shuffle(bucket)
|
91 |
+
|
92 |
+
def sort(self):
|
93 |
+
# 解像度順にソートする(表示時、メタデータ格納時の見栄えをよくするためだけ)。bucketsも入れ替えてreso_to_idも振り直す
|
94 |
+
sorted_resos = self.resos.copy()
|
95 |
+
sorted_resos.sort()
|
96 |
+
|
97 |
+
sorted_buckets = []
|
98 |
+
sorted_reso_to_id = {}
|
99 |
+
for i, reso in enumerate(sorted_resos):
|
100 |
+
bucket_id = self.reso_to_id[reso]
|
101 |
+
sorted_buckets.append(self.buckets[bucket_id])
|
102 |
+
sorted_reso_to_id[reso] = i
|
103 |
+
|
104 |
+
self.resos = sorted_resos
|
105 |
+
self.buckets = sorted_buckets
|
106 |
+
self.reso_to_id = sorted_reso_to_id
|
107 |
+
|
108 |
+
def make_buckets(self):
|
109 |
+
resos = model_util.make_bucket_resolutions(self.max_reso, self.min_size, self.max_size, self.reso_steps)
|
110 |
+
self.set_predefined_resos(resos)
|
111 |
+
|
112 |
+
def set_predefined_resos(self, resos):
|
113 |
+
# 規定サイズから選ぶ場合の解像度、aspect ratioの情報を格納しておく
|
114 |
+
self.predefined_resos = resos.copy()
|
115 |
+
self.predefined_resos_set = set(resos)
|
116 |
+
self.predefined_aspect_ratios = np.array([w / h for w, h in resos])
|
117 |
+
|
118 |
+
def add_if_new_reso(self, reso):
|
119 |
+
if reso not in self.reso_to_id:
|
120 |
+
bucket_id = len(self.resos)
|
121 |
+
self.reso_to_id[reso] = bucket_id
|
122 |
+
self.resos.append(reso)
|
123 |
+
self.buckets.append([])
|
124 |
+
# print(reso, bucket_id, len(self.buckets))
|
125 |
+
|
126 |
+
def round_to_steps(self, x):
|
127 |
+
x = int(x + .5)
|
128 |
+
return x - x % self.reso_steps
|
129 |
+
|
130 |
+
def select_bucket(self, image_width, image_height):
|
131 |
+
aspect_ratio = image_width / image_height
|
132 |
+
if not self.no_upscale:
|
133 |
+
# 同じaspect ratioがあるかもしれないので(fine tuningで、no_upscale=Trueで前処理した場合)、解像度が同じものを優先する
|
134 |
+
reso = (image_width, image_height)
|
135 |
+
if reso in self.predefined_resos_set:
|
136 |
+
pass
|
137 |
+
else:
|
138 |
+
ar_errors = self.predefined_aspect_ratios - aspect_ratio
|
139 |
+
predefined_bucket_id = np.abs(ar_errors).argmin() # 当該解像度以外でaspect ratio errorが最も少ないもの
|
140 |
+
reso = self.predefined_resos[predefined_bucket_id]
|
141 |
+
|
142 |
+
ar_reso = reso[0] / reso[1]
|
143 |
+
if aspect_ratio > ar_reso: # 横が長い→縦を合わせる
|
144 |
+
scale = reso[1] / image_height
|
145 |
+
else:
|
146 |
+
scale = reso[0] / image_width
|
147 |
+
|
148 |
+
resized_size = (int(image_width * scale + .5), int(image_height * scale + .5))
|
149 |
+
# print("use predef", image_width, image_height, reso, resized_size)
|
150 |
+
else:
|
151 |
+
if image_width * image_height > self.max_area:
|
152 |
+
# 画像が大きすぎるのでアスペクト比を保ったまま縮小することを前提にbucketを決める
|
153 |
+
resized_width = math.sqrt(self.max_area * aspect_ratio)
|
154 |
+
resized_height = self.max_area / resized_width
|
155 |
+
assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal"
|
156 |
+
|
157 |
+
# リサイズ後の短辺または長辺をreso_steps単位にする:aspect ratioの差が少ないほうを選ぶ
|
158 |
+
# 元のbucketingと同じロジック
|
159 |
+
b_width_rounded = self.round_to_steps(resized_width)
|
160 |
+
b_height_in_wr = self.round_to_steps(b_width_rounded / aspect_ratio)
|
161 |
+
ar_width_rounded = b_width_rounded / b_height_in_wr
|
162 |
+
|
163 |
+
b_height_rounded = self.round_to_steps(resized_height)
|
164 |
+
b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio)
|
165 |
+
ar_height_rounded = b_width_in_hr / b_height_rounded
|
166 |
+
|
167 |
+
# print(b_width_rounded, b_height_in_wr, ar_width_rounded)
|
168 |
+
# print(b_width_in_hr, b_height_rounded, ar_height_rounded)
|
169 |
+
|
170 |
+
if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio):
|
171 |
+
resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + .5))
|
172 |
+
else:
|
173 |
+
resized_size = (int(b_height_rounded * aspect_ratio + .5), b_height_rounded)
|
174 |
+
# print(resized_size)
|
175 |
+
else:
|
176 |
+
resized_size = (image_width, image_height) # リサイズは不要
|
177 |
+
|
178 |
+
# 画像のサイズ未満をbucketのサイズとする(paddingせずにcroppingする)
|
179 |
+
bucket_width = resized_size[0] - resized_size[0] % self.reso_steps
|
180 |
+
bucket_height = resized_size[1] - resized_size[1] % self.reso_steps
|
181 |
+
# print("use arbitrary", image_width, image_height, resized_size, bucket_width, bucket_height)
|
182 |
+
|
183 |
+
reso = (bucket_width, bucket_height)
|
184 |
+
|
185 |
+
self.add_if_new_reso(reso)
|
186 |
+
|
187 |
+
ar_error = (reso[0] / reso[1]) - aspect_ratio
|
188 |
+
return reso, resized_size, ar_error
|
189 |
+
|
190 |
+
|
191 |
+
class BucketBatchIndex(NamedTuple):
|
192 |
+
bucket_index: int
|
193 |
+
bucket_batch_size: int
|
194 |
+
batch_index: int
|
195 |
+
|
196 |
+
|
197 |
+
class BaseDataset(torch.utils.data.Dataset):
|
198 |
+
def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, random_crop, debug_dataset: bool) -> None:
|
199 |
+
super().__init__()
|
200 |
+
self.tokenizer: CLIPTokenizer = tokenizer
|
201 |
+
self.max_token_length = max_token_length
|
202 |
+
self.shuffle_caption = shuffle_caption
|
203 |
+
self.shuffle_keep_tokens = shuffle_keep_tokens
|
204 |
+
# width/height is used when enable_bucket==False
|
205 |
+
self.width, self.height = (None, None) if resolution is None else resolution
|
206 |
+
self.face_crop_aug_range = face_crop_aug_range
|
207 |
+
self.flip_aug = flip_aug
|
208 |
+
self.color_aug = color_aug
|
209 |
+
self.debug_dataset = debug_dataset
|
210 |
+
self.random_crop = random_crop
|
211 |
+
self.token_padding_disabled = False
|
212 |
+
self.dataset_dirs_info = {}
|
213 |
+
self.reg_dataset_dirs_info = {}
|
214 |
+
self.tag_frequency = {}
|
215 |
+
|
216 |
+
self.enable_bucket = False
|
217 |
+
self.bucket_manager: BucketManager = None # not initialized
|
218 |
+
self.min_bucket_reso = None
|
219 |
+
self.max_bucket_reso = None
|
220 |
+
self.bucket_reso_steps = None
|
221 |
+
self.bucket_no_upscale = None
|
222 |
+
self.bucket_info = None # for metadata
|
223 |
+
|
224 |
+
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
225 |
+
|
226 |
+
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
|
227 |
+
self.dropout_rate: float = 0
|
228 |
+
self.dropout_every_n_epochs: int = None
|
229 |
+
|
230 |
+
# augmentation
|
231 |
+
flip_p = 0.5 if flip_aug else 0.0
|
232 |
+
if color_aug:
|
233 |
+
# わりと弱めの色合いaugmentation:brightness/contrastあたりは画像のpixel valueの最大値・最小値を変えてしまうのでよくないのではという想定でgamma/hueあたりを触る
|
234 |
+
self.aug = albu.Compose([
|
235 |
+
albu.OneOf([
|
236 |
+
albu.HueSaturationValue(8, 0, 0, p=.5),
|
237 |
+
albu.RandomGamma((95, 105), p=.5),
|
238 |
+
], p=.33),
|
239 |
+
albu.HorizontalFlip(p=flip_p)
|
240 |
+
], p=1.)
|
241 |
+
elif flip_aug:
|
242 |
+
self.aug = albu.Compose([
|
243 |
+
albu.HorizontalFlip(p=flip_p)
|
244 |
+
], p=1.)
|
245 |
+
else:
|
246 |
+
self.aug = None
|
247 |
+
|
248 |
+
self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
|
249 |
+
|
250 |
+
self.image_data: Dict[str, ImageInfo] = {}
|
251 |
+
|
252 |
+
self.replacements = {}
|
253 |
+
|
254 |
+
def set_current_epoch(self, epoch):
|
255 |
+
self.current_epoch = epoch
|
256 |
+
|
257 |
+
def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
|
258 |
+
# コンストラクタで渡さないのはTextual Inversionで意識したくないから(と��うことにしておく)
|
259 |
+
self.dropout_rate = dropout_rate
|
260 |
+
self.dropout_every_n_epochs = dropout_every_n_epochs
|
261 |
+
self.tag_dropout_rate = tag_dropout_rate
|
262 |
+
|
263 |
+
def set_tag_frequency(self, dir_name, captions):
|
264 |
+
frequency_for_dir = self.tag_frequency.get(dir_name, {})
|
265 |
+
self.tag_frequency[dir_name] = frequency_for_dir
|
266 |
+
for caption in captions:
|
267 |
+
for tag in caption.split(","):
|
268 |
+
if tag and not tag.isspace():
|
269 |
+
tag = tag.lower()
|
270 |
+
frequency = frequency_for_dir.get(tag, 0)
|
271 |
+
frequency_for_dir[tag] = frequency + 1
|
272 |
+
|
273 |
+
def disable_token_padding(self):
|
274 |
+
self.token_padding_disabled = True
|
275 |
+
|
276 |
+
def add_replacement(self, str_from, str_to):
|
277 |
+
self.replacements[str_from] = str_to
|
278 |
+
|
279 |
+
def process_caption(self, caption):
|
280 |
+
# dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
|
281 |
+
is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate
|
282 |
+
is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0
|
283 |
+
|
284 |
+
if is_drop_out:
|
285 |
+
caption = ""
|
286 |
+
else:
|
287 |
+
if self.shuffle_caption:
|
288 |
+
def dropout_tags(tokens):
|
289 |
+
if self.tag_dropout_rate <= 0:
|
290 |
+
return tokens
|
291 |
+
l = []
|
292 |
+
for token in tokens:
|
293 |
+
if random.random() >= self.tag_dropout_rate:
|
294 |
+
l.append(token)
|
295 |
+
return l
|
296 |
+
|
297 |
+
tokens = [t.strip() for t in caption.strip().split(",")]
|
298 |
+
if self.shuffle_keep_tokens is None:
|
299 |
+
random.shuffle(tokens)
|
300 |
+
tokens = dropout_tags(tokens)
|
301 |
+
else:
|
302 |
+
if len(tokens) > self.shuffle_keep_tokens:
|
303 |
+
keep_tokens = tokens[:self.shuffle_keep_tokens]
|
304 |
+
tokens = tokens[self.shuffle_keep_tokens:]
|
305 |
+
random.shuffle(tokens)
|
306 |
+
tokens = dropout_tags(tokens)
|
307 |
+
|
308 |
+
tokens = keep_tokens + tokens
|
309 |
+
caption = ", ".join(tokens)
|
310 |
+
|
311 |
+
# textual inversion対応
|
312 |
+
for str_from, str_to in self.replacements.items():
|
313 |
+
if str_from == "":
|
314 |
+
# replace all
|
315 |
+
if type(str_to) == list:
|
316 |
+
caption = random.choice(str_to)
|
317 |
+
else:
|
318 |
+
caption = str_to
|
319 |
+
else:
|
320 |
+
caption = caption.replace(str_from, str_to)
|
321 |
+
|
322 |
+
return caption
|
323 |
+
|
324 |
+
def get_input_ids(self, caption):
|
325 |
+
input_ids = self.tokenizer(caption, padding="max_length", truncation=True,
|
326 |
+
max_length=self.tokenizer_max_length, return_tensors="pt").input_ids
|
327 |
+
|
328 |
+
if self.tokenizer_max_length > self.tokenizer.model_max_length:
|
329 |
+
input_ids = input_ids.squeeze(0)
|
330 |
+
iids_list = []
|
331 |
+
if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
|
332 |
+
# v1
|
333 |
+
# 77以上の時は "<BOS> .... <EOS> <EOS> <EOS>" でトータル227とかになっているので、"<BOS>...<EOS>"の三連に変換する
|
334 |
+
# 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に
|
335 |
+
for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): # (1, 152, 75)
|
336 |
+
ids_chunk = (input_ids[0].unsqueeze(0),
|
337 |
+
input_ids[i:i + self.tokenizer.model_max_length - 2],
|
338 |
+
input_ids[-1].unsqueeze(0))
|
339 |
+
ids_chunk = torch.cat(ids_chunk)
|
340 |
+
iids_list.append(ids_chunk)
|
341 |
+
else:
|
342 |
+
# v2
|
343 |
+
# 77以上の時は "<BOS> .... <EOS> <PAD> <PAD>..." でトータル227とかになっているので、"<BOS>...<EOS> <PAD> <PAD> ..."の三連に変換する
|
344 |
+
for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2):
|
345 |
+
ids_chunk = (input_ids[0].unsqueeze(0), # BOS
|
346 |
+
input_ids[i:i + self.tokenizer.model_max_length - 2],
|
347 |
+
input_ids[-1].unsqueeze(0)) # PAD or EOS
|
348 |
+
ids_chunk = torch.cat(ids_chunk)
|
349 |
+
|
350 |
+
# 末尾が <EOS> <PAD> または <PAD> <PAD> の場合は、何もしなくてよい
|
351 |
+
# 末尾が x <PAD/EOS> の場合は末尾を <EOS> に変える(x <EOS> なら結果的に変化なし)
|
352 |
+
if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id:
|
353 |
+
ids_chunk[-1] = self.tokenizer.eos_token_id
|
354 |
+
# 先頭が <BOS> <PAD> ... の場合は <BOS> <EOS> <PAD> ... に変える
|
355 |
+
if ids_chunk[1] == self.tokenizer.pad_token_id:
|
356 |
+
ids_chunk[1] = self.tokenizer.eos_token_id
|
357 |
+
|
358 |
+
iids_list.append(ids_chunk)
|
359 |
+
|
360 |
+
input_ids = torch.stack(iids_list) # 3,77
|
361 |
+
return input_ids
|
362 |
+
|
363 |
+
def register_image(self, info: ImageInfo):
|
364 |
+
self.image_data[info.image_key] = info
|
365 |
+
|
366 |
+
def make_buckets(self):
|
367 |
+
'''
|
368 |
+
bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る)
|
369 |
+
min_size and max_size are ignored when enable_bucket is False
|
370 |
+
'''
|
371 |
+
print("loading image sizes.")
|
372 |
+
for info in tqdm(self.image_data.values()):
|
373 |
+
if info.image_size is None:
|
374 |
+
info.image_size = self.get_image_size(info.absolute_path)
|
375 |
+
|
376 |
+
if self.enable_bucket:
|
377 |
+
print("make buckets")
|
378 |
+
else:
|
379 |
+
print("prepare dataset")
|
380 |
+
|
381 |
+
# bucketを作成し、画像をbucketに振り分ける
|
382 |
+
if self.enable_bucket:
|
383 |
+
if self.bucket_manager is None: # fine tuningの場合でmetadataに定義がある場合は、すでに初期化済み
|
384 |
+
self.bucket_manager = BucketManager(self.bucket_no_upscale, (self.width, self.height),
|
385 |
+
self.min_bucket_reso, self.max_bucket_reso, self.bucket_reso_steps)
|
386 |
+
if not self.bucket_no_upscale:
|
387 |
+
self.bucket_manager.make_buckets()
|
388 |
+
else:
|
389 |
+
print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます")
|
390 |
+
|
391 |
+
img_ar_errors = []
|
392 |
+
for image_info in self.image_data.values():
|
393 |
+
image_width, image_height = image_info.image_size
|
394 |
+
image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket(image_width, image_height)
|
395 |
+
|
396 |
+
# print(image_info.image_key, image_info.bucket_reso)
|
397 |
+
img_ar_errors.append(abs(ar_error))
|
398 |
+
|
399 |
+
self.bucket_manager.sort()
|
400 |
+
else:
|
401 |
+
self.bucket_manager = BucketManager(False, (self.width, self.height), None, None, None)
|
402 |
+
self.bucket_manager.set_predefined_resos([(self.width, self.height)]) # ひとつの固定サイズbucketのみ
|
403 |
+
for image_info in self.image_data.values():
|
404 |
+
image_width, image_height = image_info.image_size
|
405 |
+
image_info.bucket_reso, image_info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height)
|
406 |
+
|
407 |
+
for image_info in self.image_data.values():
|
408 |
+
for _ in range(image_info.num_repeats):
|
409 |
+
self.bucket_manager.add_image(image_info.bucket_reso, image_info.image_key)
|
410 |
+
|
411 |
+
# bucket情報を表示、格納する
|
412 |
+
if self.enable_bucket:
|
413 |
+
self.bucket_info = {"buckets": {}}
|
414 |
+
print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)")
|
415 |
+
for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)):
|
416 |
+
count = len(bucket)
|
417 |
+
if count > 0:
|
418 |
+
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
|
419 |
+
print(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
|
420 |
+
|
421 |
+
img_ar_errors = np.array(img_ar_errors)
|
422 |
+
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
|
423 |
+
self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
|
424 |
+
print(f"mean ar error (without repeats): {mean_img_ar_error}")
|
425 |
+
|
426 |
+
# データ参照用indexを作る。このindexはdatasetのshuffleに用いられる
|
427 |
+
self.buckets_indices: List(BucketBatchIndex) = []
|
428 |
+
for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
|
429 |
+
# bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは
|
430 |
+
# ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう
|
431 |
+
# そのためバッチサイズを画像種類までに制限する
|
432 |
+
# ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない?
|
433 |
+
# TODO 正則化画像をepochまたがりで利用する仕組み
|
434 |
+
num_of_image_types = len(set(bucket))
|
435 |
+
bucket_batch_size = min(self.batch_size, num_of_image_types)
|
436 |
+
batch_count = int(math.ceil(len(bucket) / bucket_batch_size))
|
437 |
+
# print(bucket_index, num_of_image_types, bucket_batch_size, batch_count)
|
438 |
+
for batch_index in range(batch_count):
|
439 |
+
self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index))
|
440 |
+
|
441 |
+
self.shuffle_buckets()
|
442 |
+
self._length = len(self.buckets_indices)
|
443 |
+
|
444 |
+
def shuffle_buckets(self):
|
445 |
+
random.shuffle(self.buckets_indices)
|
446 |
+
self.bucket_manager.shuffle()
|
447 |
+
|
448 |
+
def load_image(self, image_path):
|
449 |
+
image = Image.open(image_path)
|
450 |
+
if not image.mode == "RGB":
|
451 |
+
image = image.convert("RGB")
|
452 |
+
img = np.array(image, np.uint8)
|
453 |
+
return img
|
454 |
+
|
455 |
+
def trim_and_resize_if_required(self, image, reso, resized_size):
|
456 |
+
image_height, image_width = image.shape[0:2]
|
457 |
+
|
458 |
+
if image_width != resized_size[0] or image_height != resized_size[1]:
|
459 |
+
# リサイズする
|
460 |
+
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
|
461 |
+
|
462 |
+
image_height, image_width = image.shape[0:2]
|
463 |
+
if image_width > reso[0]:
|
464 |
+
trim_size = image_width - reso[0]
|
465 |
+
p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
|
466 |
+
# print("w", trim_size, p)
|
467 |
+
image = image[:, p:p + reso[0]]
|
468 |
+
if image_height > reso[1]:
|
469 |
+
trim_size = image_height - reso[1]
|
470 |
+
p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
|
471 |
+
# print("h", trim_size, p)
|
472 |
+
image = image[p:p + reso[1]]
|
473 |
+
|
474 |
+
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
475 |
+
return image
|
476 |
+
|
477 |
+
def cache_latents(self, vae):
|
478 |
+
# TODO ここを高速化したい
|
479 |
+
print("caching latents.")
|
480 |
+
for info in tqdm(self.image_data.values()):
|
481 |
+
if info.latents_npz is not None:
|
482 |
+
info.latents = self.load_latents_from_npz(info, False)
|
483 |
+
info.latents = torch.FloatTensor(info.latents)
|
484 |
+
info.latents_flipped = self.load_latents_from_npz(info, True) # might be None
|
485 |
+
if info.latents_flipped is not None:
|
486 |
+
info.latents_flipped = torch.FloatTensor(info.latents_flipped)
|
487 |
+
continue
|
488 |
+
|
489 |
+
image = self.load_image(info.absolute_path)
|
490 |
+
image = self.trim_and_resize_if_required(image, info.bucket_reso, info.resized_size)
|
491 |
+
|
492 |
+
img_tensor = self.image_transforms(image)
|
493 |
+
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
494 |
+
info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
|
495 |
+
|
496 |
+
if self.flip_aug:
|
497 |
+
image = image[:, ::-1].copy() # cannot convert to Tensor without copy
|
498 |
+
img_tensor = self.image_transforms(image)
|
499 |
+
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
|
500 |
+
info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
|
501 |
+
|
502 |
+
def get_image_size(self, image_path):
|
503 |
+
image = Image.open(image_path)
|
504 |
+
return image.size
|
505 |
+
|
506 |
+
def load_image_with_face_info(self, image_path: str):
|
507 |
+
img = self.load_image(image_path)
|
508 |
+
|
509 |
+
face_cx = face_cy = face_w = face_h = 0
|
510 |
+
if self.face_crop_aug_range is not None:
|
511 |
+
tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
|
512 |
+
if len(tokens) >= 5:
|
513 |
+
face_cx = int(tokens[-4])
|
514 |
+
face_cy = int(tokens[-3])
|
515 |
+
face_w = int(tokens[-2])
|
516 |
+
face_h = int(tokens[-1])
|
517 |
+
|
518 |
+
return img, face_cx, face_cy, face_w, face_h
|
519 |
+
|
520 |
+
# いい感じに切り出す
|
521 |
+
def crop_target(self, image, face_cx, face_cy, face_w, face_h):
|
522 |
+
height, width = image.shape[0:2]
|
523 |
+
if height == self.height and width == self.width:
|
524 |
+
return image
|
525 |
+
|
526 |
+
# 画像サイズはsizeより大きいのでリサイズする
|
527 |
+
face_size = max(face_w, face_h)
|
528 |
+
min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
|
529 |
+
min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # 指定した顔最小サイズ
|
530 |
+
max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # 指定した顔最大サイズ
|
531 |
+
if min_scale >= max_scale: # range指定がmin==max
|
532 |
+
scale = min_scale
|
533 |
+
else:
|
534 |
+
scale = random.uniform(min_scale, max_scale)
|
535 |
+
|
536 |
+
nh = int(height * scale + .5)
|
537 |
+
nw = int(width * scale + .5)
|
538 |
+
assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
|
539 |
+
image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA)
|
540 |
+
face_cx = int(face_cx * scale + .5)
|
541 |
+
face_cy = int(face_cy * scale + .5)
|
542 |
+
height, width = nh, nw
|
543 |
+
|
544 |
+
# 顔を中心として448*640とかへ切り出す
|
545 |
+
for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
|
546 |
+
p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置
|
547 |
+
|
548 |
+
if self.random_crop:
|
549 |
+
# 背景も含めるために顔を中心に置く確率を高めつつずらす
|
550 |
+
range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう
|
551 |
+
p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数
|
552 |
+
else:
|
553 |
+
# range指定があるときのみ、すこしだけランダムに(わりと適当)
|
554 |
+
if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]:
|
555 |
+
if face_size > self.size // 10 and face_size >= 40:
|
556 |
+
p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
|
557 |
+
|
558 |
+
p1 = max(0, min(p1, length - target_size))
|
559 |
+
|
560 |
+
if axis == 0:
|
561 |
+
image = image[p1:p1 + target_size, :]
|
562 |
+
else:
|
563 |
+
image = image[:, p1:p1 + target_size]
|
564 |
+
|
565 |
+
return image
|
566 |
+
|
567 |
+
def load_latents_from_npz(self, image_info: ImageInfo, flipped):
|
568 |
+
npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz
|
569 |
+
if npz_file is None:
|
570 |
+
return None
|
571 |
+
return np.load(npz_file)['arr_0']
|
572 |
+
|
573 |
+
def __len__(self):
|
574 |
+
return self._length
|
575 |
+
|
576 |
+
def __getitem__(self, index):
|
577 |
+
if index == 0:
|
578 |
+
self.shuffle_buckets()
|
579 |
+
|
580 |
+
bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
|
581 |
+
bucket_batch_size = self.buckets_indices[index].bucket_batch_size
|
582 |
+
image_index = self.buckets_indices[index].batch_index * bucket_batch_size
|
583 |
+
|
584 |
+
loss_weights = []
|
585 |
+
captions = []
|
586 |
+
input_ids_list = []
|
587 |
+
latents_list = []
|
588 |
+
images = []
|
589 |
+
|
590 |
+
for image_key in bucket[image_index:image_index + bucket_batch_size]:
|
591 |
+
image_info = self.image_data[image_key]
|
592 |
+
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
593 |
+
|
594 |
+
# image/latentsを処理する
|
595 |
+
if image_info.latents is not None:
|
596 |
+
latents = image_info.latents if not self.flip_aug or random.random() < .5 else image_info.latents_flipped
|
597 |
+
image = None
|
598 |
+
elif image_info.latents_npz is not None:
|
599 |
+
latents = self.load_latents_from_npz(image_info, self.flip_aug and random.random() >= .5)
|
600 |
+
latents = torch.FloatTensor(latents)
|
601 |
+
image = None
|
602 |
+
else:
|
603 |
+
# 画像を読み込み、必要ならcropする
|
604 |
+
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path)
|
605 |
+
im_h, im_w = img.shape[0:2]
|
606 |
+
|
607 |
+
if self.enable_bucket:
|
608 |
+
img = self.trim_and_resize_if_required(img, image_info.bucket_reso, image_info.resized_size)
|
609 |
+
else:
|
610 |
+
if face_cx > 0: # 顔位置情報あり
|
611 |
+
img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
|
612 |
+
elif im_h > self.height or im_w > self.width:
|
613 |
+
assert self.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
|
614 |
+
if im_h > self.height:
|
615 |
+
p = random.randint(0, im_h - self.height)
|
616 |
+
img = img[p:p + self.height]
|
617 |
+
if im_w > self.width:
|
618 |
+
p = random.randint(0, im_w - self.width)
|
619 |
+
img = img[:, p:p + self.width]
|
620 |
+
|
621 |
+
im_h, im_w = img.shape[0:2]
|
622 |
+
assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
623 |
+
|
624 |
+
# augmentation
|
625 |
+
if self.aug is not None:
|
626 |
+
img = self.aug(image=img)['image']
|
627 |
+
|
628 |
+
latents = None
|
629 |
+
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる
|
630 |
+
|
631 |
+
images.append(image)
|
632 |
+
latents_list.append(latents)
|
633 |
+
|
634 |
+
caption = self.process_caption(image_info.caption)
|
635 |
+
captions.append(caption)
|
636 |
+
if not self.token_padding_disabled: # this option might be omitted in future
|
637 |
+
input_ids_list.append(self.get_input_ids(caption))
|
638 |
+
|
639 |
+
example = {}
|
640 |
+
example['loss_weights'] = torch.FloatTensor(loss_weights)
|
641 |
+
|
642 |
+
if self.token_padding_disabled:
|
643 |
+
# padding=True means pad in the batch
|
644 |
+
example['input_ids'] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
|
645 |
+
else:
|
646 |
+
# batch processing seems to be good
|
647 |
+
example['input_ids'] = torch.stack(input_ids_list)
|
648 |
+
|
649 |
+
if images[0] is not None:
|
650 |
+
images = torch.stack(images)
|
651 |
+
images = images.to(memory_format=torch.contiguous_format).float()
|
652 |
+
else:
|
653 |
+
images = None
|
654 |
+
example['images'] = images
|
655 |
+
|
656 |
+
example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None
|
657 |
+
|
658 |
+
if self.debug_dataset:
|
659 |
+
example['image_keys'] = bucket[image_index:image_index + self.batch_size]
|
660 |
+
example['captions'] = captions
|
661 |
+
return example
|
662 |
+
|
663 |
+
|
664 |
+
class DreamBoothDataset(BaseDataset):
|
665 |
+
def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
|
666 |
+
super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
|
667 |
+
resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
|
668 |
+
|
669 |
+
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
670 |
+
|
671 |
+
self.batch_size = batch_size
|
672 |
+
self.size = min(self.width, self.height) # 短いほう
|
673 |
+
self.prior_loss_weight = prior_loss_weight
|
674 |
+
self.latents_cache = None
|
675 |
+
|
676 |
+
self.enable_bucket = enable_bucket
|
677 |
+
if self.enable_bucket:
|
678 |
+
assert min(resolution) >= min_bucket_reso, f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
|
679 |
+
assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
|
680 |
+
self.min_bucket_reso = min_bucket_reso
|
681 |
+
self.max_bucket_reso = max_bucket_reso
|
682 |
+
self.bucket_reso_steps = bucket_reso_steps
|
683 |
+
self.bucket_no_upscale = bucket_no_upscale
|
684 |
+
else:
|
685 |
+
self.min_bucket_reso = None
|
686 |
+
self.max_bucket_reso = None
|
687 |
+
self.bucket_reso_steps = None # この情報は使われない
|
688 |
+
self.bucket_no_upscale = False
|
689 |
+
|
690 |
+
def read_caption(img_path):
|
691 |
+
# captionの候補ファイル名を作る
|
692 |
+
base_name = os.path.splitext(img_path)[0]
|
693 |
+
base_name_face_det = base_name
|
694 |
+
tokens = base_name.split("_")
|
695 |
+
if len(tokens) >= 5:
|
696 |
+
base_name_face_det = "_".join(tokens[:-4])
|
697 |
+
cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]
|
698 |
+
|
699 |
+
caption = None
|
700 |
+
for cap_path in cap_paths:
|
701 |
+
if os.path.isfile(cap_path):
|
702 |
+
with open(cap_path, "rt", encoding='utf-8') as f:
|
703 |
+
try:
|
704 |
+
lines = f.readlines()
|
705 |
+
except UnicodeDecodeError as e:
|
706 |
+
print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
|
707 |
+
raise e
|
708 |
+
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
|
709 |
+
caption = lines[0].strip()
|
710 |
+
break
|
711 |
+
return caption
|
712 |
+
|
713 |
+
def load_dreambooth_dir(dir):
|
714 |
+
if not os.path.isdir(dir):
|
715 |
+
# print(f"ignore file: {dir}")
|
716 |
+
return 0, [], []
|
717 |
+
|
718 |
+
tokens = os.path.basename(dir).split('_')
|
719 |
+
try:
|
720 |
+
n_repeats = int(tokens[0])
|
721 |
+
except ValueError as e:
|
722 |
+
print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}")
|
723 |
+
return 0, [], []
|
724 |
+
|
725 |
+
caption_by_folder = '_'.join(tokens[1:])
|
726 |
+
img_paths = glob_images(dir, "*")
|
727 |
+
print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files")
|
728 |
+
|
729 |
+
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
730 |
+
captions = []
|
731 |
+
for img_path in img_paths:
|
732 |
+
cap_for_img = read_caption(img_path)
|
733 |
+
captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
|
734 |
+
|
735 |
+
self.set_tag_frequency(os.path.basename(dir), captions) # タグ頻度を記録
|
736 |
+
|
737 |
+
return n_repeats, img_paths, captions
|
738 |
+
|
739 |
+
print("prepare train images.")
|
740 |
+
train_dirs = os.listdir(train_data_dir)
|
741 |
+
num_train_images = 0
|
742 |
+
for dir in train_dirs:
|
743 |
+
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir))
|
744 |
+
num_train_images += n_repeats * len(img_paths)
|
745 |
+
|
746 |
+
for img_path, caption in zip(img_paths, captions):
|
747 |
+
info = ImageInfo(img_path, n_repeats, caption, False, img_path)
|
748 |
+
self.register_image(info)
|
749 |
+
|
750 |
+
self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
751 |
+
|
752 |
+
print(f"{num_train_images} train images with repeating.")
|
753 |
+
self.num_train_images = num_train_images
|
754 |
+
|
755 |
+
# reg imageは数を数えて学習画像と同じ枚数にする
|
756 |
+
num_reg_images = 0
|
757 |
+
if reg_data_dir:
|
758 |
+
print("prepare reg images.")
|
759 |
+
reg_infos: List[ImageInfo] = []
|
760 |
+
|
761 |
+
reg_dirs = os.listdir(reg_data_dir)
|
762 |
+
for dir in reg_dirs:
|
763 |
+
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir))
|
764 |
+
num_reg_images += n_repeats * len(img_paths)
|
765 |
+
|
766 |
+
for img_path, caption in zip(img_paths, captions):
|
767 |
+
info = ImageInfo(img_path, n_repeats, caption, True, img_path)
|
768 |
+
reg_infos.append(info)
|
769 |
+
|
770 |
+
self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
771 |
+
|
772 |
+
print(f"{num_reg_images} reg images.")
|
773 |
+
if num_train_images < num_reg_images:
|
774 |
+
print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
|
775 |
+
|
776 |
+
if num_reg_images == 0:
|
777 |
+
print("no regularization images / 正則化画像が見つかりませんでした")
|
778 |
+
else:
|
779 |
+
# num_repeatsを計算する:どうせ大した数ではないのでループで処理する
|
780 |
+
n = 0
|
781 |
+
first_loop = True
|
782 |
+
while n < num_train_images:
|
783 |
+
for info in reg_infos:
|
784 |
+
if first_loop:
|
785 |
+
self.register_image(info)
|
786 |
+
n += info.num_repeats
|
787 |
+
else:
|
788 |
+
info.num_repeats += 1
|
789 |
+
n += 1
|
790 |
+
if n >= num_train_images:
|
791 |
+
break
|
792 |
+
first_loop = False
|
793 |
+
|
794 |
+
self.num_reg_images = num_reg_images
|
795 |
+
|
796 |
+
|
797 |
+
class FineTuningDataset(BaseDataset):
|
798 |
+
def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
|
799 |
+
super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
|
800 |
+
resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
|
801 |
+
|
802 |
+
# メタデータを読み込む
|
803 |
+
if os.path.exists(json_file_name):
|
804 |
+
print(f"loading existing metadata: {json_file_name}")
|
805 |
+
with open(json_file_name, "rt", encoding='utf-8') as f:
|
806 |
+
metadata = json.load(f)
|
807 |
+
else:
|
808 |
+
raise ValueError(f"no metadata / メタデータファイルがありません: {json_file_name}")
|
809 |
+
|
810 |
+
self.metadata = metadata
|
811 |
+
self.train_data_dir = train_data_dir
|
812 |
+
self.batch_size = batch_size
|
813 |
+
|
814 |
+
tags_list = []
|
815 |
+
for image_key, img_md in metadata.items():
|
816 |
+
# path情報を作る
|
817 |
+
if os.path.exists(image_key):
|
818 |
+
abs_path = image_key
|
819 |
+
else:
|
820 |
+
# わりといい加減だがいい方法が思いつかん
|
821 |
+
abs_path = glob_images(train_data_dir, image_key)
|
822 |
+
assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
|
823 |
+
abs_path = abs_path[0]
|
824 |
+
|
825 |
+
caption = img_md.get('caption')
|
826 |
+
tags = img_md.get('tags')
|
827 |
+
if caption is None:
|
828 |
+
caption = tags
|
829 |
+
elif tags is not None and len(tags) > 0:
|
830 |
+
caption = caption + ', ' + tags
|
831 |
+
tags_list.append(tags)
|
832 |
+
assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
|
833 |
+
|
834 |
+
image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
|
835 |
+
image_info.image_size = img_md.get('train_resolution')
|
836 |
+
|
837 |
+
if not self.color_aug and not self.random_crop:
|
838 |
+
# if npz exists, use them
|
839 |
+
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key)
|
840 |
+
|
841 |
+
self.register_image(image_info)
|
842 |
+
self.num_train_images = len(metadata) * dataset_repeats
|
843 |
+
self.num_reg_images = 0
|
844 |
+
|
845 |
+
self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
|
846 |
+
self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
|
847 |
+
|
848 |
+
# check existence of all npz files
|
849 |
+
use_npz_latents = not (self.color_aug or self.random_crop)
|
850 |
+
if use_npz_latents:
|
851 |
+
npz_any = False
|
852 |
+
npz_all = True
|
853 |
+
for image_info in self.image_data.values():
|
854 |
+
has_npz = image_info.latents_npz is not None
|
855 |
+
npz_any = npz_any or has_npz
|
856 |
+
|
857 |
+
if self.flip_aug:
|
858 |
+
has_npz = has_npz and image_info.latents_npz_flipped is not None
|
859 |
+
npz_all = npz_all and has_npz
|
860 |
+
|
861 |
+
if npz_any and not npz_all:
|
862 |
+
break
|
863 |
+
|
864 |
+
if not npz_any:
|
865 |
+
use_npz_latents = False
|
866 |
+
print(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します")
|
867 |
+
elif not npz_all:
|
868 |
+
use_npz_latents = False
|
869 |
+
print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
|
870 |
+
if self.flip_aug:
|
871 |
+
print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
|
872 |
+
# else:
|
873 |
+
# print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
|
874 |
+
|
875 |
+
# check min/max bucket size
|
876 |
+
sizes = set()
|
877 |
+
resos = set()
|
878 |
+
for image_info in self.image_data.values():
|
879 |
+
if image_info.image_size is None:
|
880 |
+
sizes = None # not calculated
|
881 |
+
break
|
882 |
+
sizes.add(image_info.image_size[0])
|
883 |
+
sizes.add(image_info.image_size[1])
|
884 |
+
resos.add(tuple(image_info.image_size))
|
885 |
+
|
886 |
+
if sizes is None:
|
887 |
+
if use_npz_latents:
|
888 |
+
use_npz_latents = False
|
889 |
+
print(f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します")
|
890 |
+
|
891 |
+
assert resolution is not None, "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください"
|
892 |
+
|
893 |
+
self.enable_bucket = enable_bucket
|
894 |
+
if self.enable_bucket:
|
895 |
+
self.min_bucket_reso = min_bucket_reso
|
896 |
+
self.max_bucket_reso = max_bucket_reso
|
897 |
+
self.bucket_reso_steps = bucket_reso_steps
|
898 |
+
self.bucket_no_upscale = bucket_no_upscale
|
899 |
+
else:
|
900 |
+
if not enable_bucket:
|
901 |
+
print("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします")
|
902 |
+
print("using bucket info in metadata / メタデータ内のbucket情報を使います")
|
903 |
+
self.enable_bucket = True
|
904 |
+
|
905 |
+
assert not bucket_no_upscale, "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません"
|
906 |
+
|
907 |
+
# bucket情報を��期化しておく、make_bucketsで再作成しない
|
908 |
+
self.bucket_manager = BucketManager(False, None, None, None, None)
|
909 |
+
self.bucket_manager.set_predefined_resos(resos)
|
910 |
+
|
911 |
+
# npz情報をきれいにしておく
|
912 |
+
if not use_npz_latents:
|
913 |
+
for image_info in self.image_data.values():
|
914 |
+
image_info.latents_npz = image_info.latents_npz_flipped = None
|
915 |
+
|
916 |
+
def image_key_to_npz_file(self, image_key):
|
917 |
+
base_name = os.path.splitext(image_key)[0]
|
918 |
+
npz_file_norm = base_name + '.npz'
|
919 |
+
|
920 |
+
if os.path.exists(npz_file_norm):
|
921 |
+
# image_key is full path
|
922 |
+
npz_file_flip = base_name + '_flip.npz'
|
923 |
+
if not os.path.exists(npz_file_flip):
|
924 |
+
npz_file_flip = None
|
925 |
+
return npz_file_norm, npz_file_flip
|
926 |
+
|
927 |
+
# image_key is relative path
|
928 |
+
npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz')
|
929 |
+
npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz')
|
930 |
+
|
931 |
+
if not os.path.exists(npz_file_norm):
|
932 |
+
npz_file_norm = None
|
933 |
+
npz_file_flip = None
|
934 |
+
elif not os.path.exists(npz_file_flip):
|
935 |
+
npz_file_flip = None
|
936 |
+
|
937 |
+
return npz_file_norm, npz_file_flip
|
938 |
+
|
939 |
+
|
940 |
+
def debug_dataset(train_dataset, show_input_ids=False):
|
941 |
+
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
942 |
+
print("Escape for exit. / Escキーで中断、終了します")
|
943 |
+
|
944 |
+
train_dataset.set_current_epoch(1)
|
945 |
+
k = 0
|
946 |
+
for i, example in enumerate(train_dataset):
|
947 |
+
if example['latents'] is not None:
|
948 |
+
print(f"sample has latents from npz file: {example['latents'].size()}")
|
949 |
+
for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
|
950 |
+
print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"')
|
951 |
+
if show_input_ids:
|
952 |
+
print(f"input ids: {iid}")
|
953 |
+
if example['images'] is not None:
|
954 |
+
im = example['images'][j]
|
955 |
+
print(f"image size: {im.size()}")
|
956 |
+
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
|
957 |
+
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
|
958 |
+
im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
|
959 |
+
if os.name == 'nt': # only windows
|
960 |
+
cv2.imshow("img", im)
|
961 |
+
k = cv2.waitKey()
|
962 |
+
cv2.destroyAllWindows()
|
963 |
+
if k == 27:
|
964 |
+
break
|
965 |
+
if k == 27 or (example['images'] is None and i >= 8):
|
966 |
+
break
|
967 |
+
|
968 |
+
|
969 |
+
def glob_images(directory, base="*"):
|
970 |
+
img_paths = []
|
971 |
+
for ext in IMAGE_EXTENSIONS:
|
972 |
+
if base == '*':
|
973 |
+
img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
|
974 |
+
else:
|
975 |
+
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
|
976 |
+
# img_paths = list(set(img_paths)) # 重複を排除
|
977 |
+
# img_paths.sort()
|
978 |
+
return img_paths
|
979 |
+
|
980 |
+
|
981 |
+
def glob_images_pathlib(dir_path, recursive):
|
982 |
+
image_paths = []
|
983 |
+
if recursive:
|
984 |
+
for ext in IMAGE_EXTENSIONS:
|
985 |
+
image_paths += list(dir_path.rglob('*' + ext))
|
986 |
+
else:
|
987 |
+
for ext in IMAGE_EXTENSIONS:
|
988 |
+
image_paths += list(dir_path.glob('*' + ext))
|
989 |
+
# image_paths = list(set(image_paths)) # 重複を排除
|
990 |
+
# image_paths.sort()
|
991 |
+
return image_paths
|
992 |
+
|
993 |
+
# endregion
|
994 |
+
|
995 |
+
|
996 |
+
# region モジュール入れ替え部
|
997 |
+
"""
|
998 |
+
高速化のためのモジュール入れ替え
|
999 |
+
"""
|
1000 |
+
|
1001 |
+
# FlashAttentionを使うCrossAttention
|
1002 |
+
# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
|
1003 |
+
# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
|
1004 |
+
|
1005 |
+
# constants
|
1006 |
+
|
1007 |
+
EPSILON = 1e-6
|
1008 |
+
|
1009 |
+
# helper functions
|
1010 |
+
|
1011 |
+
|
1012 |
+
def exists(val):
|
1013 |
+
return val is not None
|
1014 |
+
|
1015 |
+
|
1016 |
+
def default(val, d):
|
1017 |
+
return val if exists(val) else d
|
1018 |
+
|
1019 |
+
|
1020 |
+
def model_hash(filename):
|
1021 |
+
"""Old model hash used by stable-diffusion-webui"""
|
1022 |
+
try:
|
1023 |
+
with open(filename, "rb") as file:
|
1024 |
+
m = hashlib.sha256()
|
1025 |
+
|
1026 |
+
file.seek(0x100000)
|
1027 |
+
m.update(file.read(0x10000))
|
1028 |
+
return m.hexdigest()[0:8]
|
1029 |
+
except FileNotFoundError:
|
1030 |
+
return 'NOFILE'
|
1031 |
+
|
1032 |
+
|
1033 |
+
def calculate_sha256(filename):
|
1034 |
+
"""New model hash used by stable-diffusion-webui"""
|
1035 |
+
hash_sha256 = hashlib.sha256()
|
1036 |
+
blksize = 1024 * 1024
|
1037 |
+
|
1038 |
+
with open(filename, "rb") as f:
|
1039 |
+
for chunk in iter(lambda: f.read(blksize), b""):
|
1040 |
+
hash_sha256.update(chunk)
|
1041 |
+
|
1042 |
+
return hash_sha256.hexdigest()
|
1043 |
+
|
1044 |
+
|
1045 |
+
def precalculate_safetensors_hashes(tensors, metadata):
|
1046 |
+
"""Precalculate the model hashes needed by sd-webui-additional-networks to
|
1047 |
+
save time on indexing the model later."""
|
1048 |
+
|
1049 |
+
# Because writing user metadata to the file can change the result of
|
1050 |
+
# sd_models.model_hash(), only retain the training metadata for purposes of
|
1051 |
+
# calculating the hash, as they are meant to be immutable
|
1052 |
+
metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
|
1053 |
+
|
1054 |
+
bytes = safetensors.torch.save(tensors, metadata)
|
1055 |
+
b = BytesIO(bytes)
|
1056 |
+
|
1057 |
+
model_hash = addnet_hash_safetensors(b)
|
1058 |
+
legacy_hash = addnet_hash_legacy(b)
|
1059 |
+
return model_hash, legacy_hash
|
1060 |
+
|
1061 |
+
|
1062 |
+
def addnet_hash_legacy(b):
|
1063 |
+
"""Old model hash used by sd-webui-additional-networks for .safetensors format files"""
|
1064 |
+
m = hashlib.sha256()
|
1065 |
+
|
1066 |
+
b.seek(0x100000)
|
1067 |
+
m.update(b.read(0x10000))
|
1068 |
+
return m.hexdigest()[0:8]
|
1069 |
+
|
1070 |
+
|
1071 |
+
def addnet_hash_safetensors(b):
|
1072 |
+
"""New model hash used by sd-webui-additional-networks for .safetensors format files"""
|
1073 |
+
hash_sha256 = hashlib.sha256()
|
1074 |
+
blksize = 1024 * 1024
|
1075 |
+
|
1076 |
+
b.seek(0)
|
1077 |
+
header = b.read(8)
|
1078 |
+
n = int.from_bytes(header, "little")
|
1079 |
+
|
1080 |
+
offset = n + 8
|
1081 |
+
b.seek(offset)
|
1082 |
+
for chunk in iter(lambda: b.read(blksize), b""):
|
1083 |
+
hash_sha256.update(chunk)
|
1084 |
+
|
1085 |
+
return hash_sha256.hexdigest()
|
1086 |
+
|
1087 |
+
|
1088 |
+
# flash attention forwards and backwards
|
1089 |
+
|
1090 |
+
# https://arxiv.org/abs/2205.14135
|
1091 |
+
|
1092 |
+
|
1093 |
+
class FlashAttentionFunction(torch.autograd.function.Function):
|
1094 |
+
@ staticmethod
|
1095 |
+
@ torch.no_grad()
|
1096 |
+
def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
|
1097 |
+
""" Algorithm 2 in the paper """
|
1098 |
+
|
1099 |
+
device = q.device
|
1100 |
+
dtype = q.dtype
|
1101 |
+
max_neg_value = -torch.finfo(q.dtype).max
|
1102 |
+
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
1103 |
+
|
1104 |
+
o = torch.zeros_like(q)
|
1105 |
+
all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
|
1106 |
+
all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
|
1107 |
+
|
1108 |
+
scale = (q.shape[-1] ** -0.5)
|
1109 |
+
|
1110 |
+
if not exists(mask):
|
1111 |
+
mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
|
1112 |
+
else:
|
1113 |
+
mask = rearrange(mask, 'b n -> b 1 1 n')
|
1114 |
+
mask = mask.split(q_bucket_size, dim=-1)
|
1115 |
+
|
1116 |
+
row_splits = zip(
|
1117 |
+
q.split(q_bucket_size, dim=-2),
|
1118 |
+
o.split(q_bucket_size, dim=-2),
|
1119 |
+
mask,
|
1120 |
+
all_row_sums.split(q_bucket_size, dim=-2),
|
1121 |
+
all_row_maxes.split(q_bucket_size, dim=-2),
|
1122 |
+
)
|
1123 |
+
|
1124 |
+
for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
|
1125 |
+
q_start_index = ind * q_bucket_size - qk_len_diff
|
1126 |
+
|
1127 |
+
col_splits = zip(
|
1128 |
+
k.split(k_bucket_size, dim=-2),
|
1129 |
+
v.split(k_bucket_size, dim=-2),
|
1130 |
+
)
|
1131 |
+
|
1132 |
+
for k_ind, (kc, vc) in enumerate(col_splits):
|
1133 |
+
k_start_index = k_ind * k_bucket_size
|
1134 |
+
|
1135 |
+
attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
|
1136 |
+
|
1137 |
+
if exists(row_mask):
|
1138 |
+
attn_weights.masked_fill_(~row_mask, max_neg_value)
|
1139 |
+
|
1140 |
+
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
1141 |
+
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
|
1142 |
+
device=device).triu(q_start_index - k_start_index + 1)
|
1143 |
+
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
1144 |
+
|
1145 |
+
block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
|
1146 |
+
attn_weights -= block_row_maxes
|
1147 |
+
exp_weights = torch.exp(attn_weights)
|
1148 |
+
|
1149 |
+
if exists(row_mask):
|
1150 |
+
exp_weights.masked_fill_(~row_mask, 0.)
|
1151 |
+
|
1152 |
+
block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
|
1153 |
+
|
1154 |
+
new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
|
1155 |
+
|
1156 |
+
exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)
|
1157 |
+
|
1158 |
+
exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
|
1159 |
+
exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
|
1160 |
+
|
1161 |
+
new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
|
1162 |
+
|
1163 |
+
oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
|
1164 |
+
|
1165 |
+
row_maxes.copy_(new_row_maxes)
|
1166 |
+
row_sums.copy_(new_row_sums)
|
1167 |
+
|
1168 |
+
ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
|
1169 |
+
ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
|
1170 |
+
|
1171 |
+
return o
|
1172 |
+
|
1173 |
+
@ staticmethod
|
1174 |
+
@ torch.no_grad()
|
1175 |
+
def backward(ctx, do):
|
1176 |
+
""" Algorithm 4 in the paper """
|
1177 |
+
|
1178 |
+
causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
|
1179 |
+
q, k, v, o, l, m = ctx.saved_tensors
|
1180 |
+
|
1181 |
+
device = q.device
|
1182 |
+
|
1183 |
+
max_neg_value = -torch.finfo(q.dtype).max
|
1184 |
+
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
1185 |
+
|
1186 |
+
dq = torch.zeros_like(q)
|
1187 |
+
dk = torch.zeros_like(k)
|
1188 |
+
dv = torch.zeros_like(v)
|
1189 |
+
|
1190 |
+
row_splits = zip(
|
1191 |
+
q.split(q_bucket_size, dim=-2),
|
1192 |
+
o.split(q_bucket_size, dim=-2),
|
1193 |
+
do.split(q_bucket_size, dim=-2),
|
1194 |
+
mask,
|
1195 |
+
l.split(q_bucket_size, dim=-2),
|
1196 |
+
m.split(q_bucket_size, dim=-2),
|
1197 |
+
dq.split(q_bucket_size, dim=-2)
|
1198 |
+
)
|
1199 |
+
|
1200 |
+
for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
|
1201 |
+
q_start_index = ind * q_bucket_size - qk_len_diff
|
1202 |
+
|
1203 |
+
col_splits = zip(
|
1204 |
+
k.split(k_bucket_size, dim=-2),
|
1205 |
+
v.split(k_bucket_size, dim=-2),
|
1206 |
+
dk.split(k_bucket_size, dim=-2),
|
1207 |
+
dv.split(k_bucket_size, dim=-2),
|
1208 |
+
)
|
1209 |
+
|
1210 |
+
for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
|
1211 |
+
k_start_index = k_ind * k_bucket_size
|
1212 |
+
|
1213 |
+
attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
|
1214 |
+
|
1215 |
+
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
1216 |
+
causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
|
1217 |
+
device=device).triu(q_start_index - k_start_index + 1)
|
1218 |
+
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
1219 |
+
|
1220 |
+
exp_attn_weights = torch.exp(attn_weights - mc)
|
1221 |
+
|
1222 |
+
if exists(row_mask):
|
1223 |
+
exp_attn_weights.masked_fill_(~row_mask, 0.)
|
1224 |
+
|
1225 |
+
p = exp_attn_weights / lc
|
1226 |
+
|
1227 |
+
dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
|
1228 |
+
dp = einsum('... i d, ... j d -> ... i j', doc, vc)
|
1229 |
+
|
1230 |
+
D = (doc * oc).sum(dim=-1, keepdims=True)
|
1231 |
+
ds = p * scale * (dp - D)
|
1232 |
+
|
1233 |
+
dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
|
1234 |
+
dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)
|
1235 |
+
|
1236 |
+
dqc.add_(dq_chunk)
|
1237 |
+
dkc.add_(dk_chunk)
|
1238 |
+
dvc.add_(dv_chunk)
|
1239 |
+
|
1240 |
+
return dq, dk, dv, None, None, None, None
|
1241 |
+
|
1242 |
+
|
1243 |
+
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
|
1244 |
+
if mem_eff_attn:
|
1245 |
+
replace_unet_cross_attn_to_memory_efficient()
|
1246 |
+
elif xformers:
|
1247 |
+
replace_unet_cross_attn_to_xformers()
|
1248 |
+
|
1249 |
+
|
1250 |
+
def replace_unet_cross_attn_to_memory_efficient():
|
1251 |
+
print("Replace CrossAttention.forward to use FlashAttention (not xformers)")
|
1252 |
+
flash_func = FlashAttentionFunction
|
1253 |
+
|
1254 |
+
def forward_flash_attn(self, x, context=None, mask=None):
|
1255 |
+
q_bucket_size = 512
|
1256 |
+
k_bucket_size = 1024
|
1257 |
+
|
1258 |
+
h = self.heads
|
1259 |
+
q = self.to_q(x)
|
1260 |
+
|
1261 |
+
context = context if context is not None else x
|
1262 |
+
context = context.to(x.dtype)
|
1263 |
+
|
1264 |
+
if hasattr(self, 'hypernetwork') and self.hypernetwork is not None:
|
1265 |
+
context_k, context_v = self.hypernetwork.forward(x, context)
|
1266 |
+
context_k = context_k.to(x.dtype)
|
1267 |
+
context_v = context_v.to(x.dtype)
|
1268 |
+
else:
|
1269 |
+
context_k = context
|
1270 |
+
context_v = context
|
1271 |
+
|
1272 |
+
k = self.to_k(context_k)
|
1273 |
+
v = self.to_v(context_v)
|
1274 |
+
del context, x
|
1275 |
+
|
1276 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
|
1277 |
+
|
1278 |
+
out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
|
1279 |
+
|
1280 |
+
out = rearrange(out, 'b h n d -> b n (h d)')
|
1281 |
+
|
1282 |
+
# diffusers 0.7.0~ わざわざ変えるなよ (;´Д`)
|
1283 |
+
out = self.to_out[0](out)
|
1284 |
+
out = self.to_out[1](out)
|
1285 |
+
return out
|
1286 |
+
|
1287 |
+
diffusers.models.attention.CrossAttention.forward = forward_flash_attn
|
1288 |
+
|
1289 |
+
|
1290 |
+
def replace_unet_cross_attn_to_xformers():
|
1291 |
+
print("Replace CrossAttention.forward to use xformers")
|
1292 |
+
try:
|
1293 |
+
import xformers.ops
|
1294 |
+
except ImportError:
|
1295 |
+
raise ImportError("No xformers / xformersがインストールされていないようです")
|
1296 |
+
|
1297 |
+
def forward_xformers(self, x, context=None, mask=None):
|
1298 |
+
h = self.heads
|
1299 |
+
q_in = self.to_q(x)
|
1300 |
+
|
1301 |
+
context = default(context, x)
|
1302 |
+
context = context.to(x.dtype)
|
1303 |
+
|
1304 |
+
if hasattr(self, 'hypernetwork') and self.hypernetwork is not None:
|
1305 |
+
context_k, context_v = self.hypernetwork.forward(x, context)
|
1306 |
+
context_k = context_k.to(x.dtype)
|
1307 |
+
context_v = context_v.to(x.dtype)
|
1308 |
+
else:
|
1309 |
+
context_k = context
|
1310 |
+
context_v = context
|
1311 |
+
|
1312 |
+
k_in = self.to_k(context_k)
|
1313 |
+
v_in = self.to_v(context_v)
|
1314 |
+
|
1315 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
|
1316 |
+
del q_in, k_in, v_in
|
1317 |
+
|
1318 |
+
q = q.contiguous()
|
1319 |
+
k = k.contiguous()
|
1320 |
+
v = v.contiguous()
|
1321 |
+
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる
|
1322 |
+
|
1323 |
+
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
1324 |
+
|
1325 |
+
# diffusers 0.7.0~
|
1326 |
+
out = self.to_out[0](out)
|
1327 |
+
out = self.to_out[1](out)
|
1328 |
+
return out
|
1329 |
+
|
1330 |
+
diffusers.models.attention.CrossAttention.forward = forward_xformers
|
1331 |
+
# endregion
|
1332 |
+
|
1333 |
+
|
1334 |
+
# region arguments
|
1335 |
+
|
1336 |
+
def add_sd_models_arguments(parser: argparse.ArgumentParser):
|
1337 |
+
# for pretrained models
|
1338 |
+
parser.add_argument("--v2", action='store_true',
|
1339 |
+
help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
|
1340 |
+
parser.add_argument("--v_parameterization", action='store_true',
|
1341 |
+
help='enable v-parameterization training / v-parameterization学習を有効にする')
|
1342 |
+
parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
|
1343 |
+
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
|
1344 |
+
|
1345 |
+
|
1346 |
+
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
|
1347 |
+
parser.add_argument("--output_dir", type=str, default=None,
|
1348 |
+
help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
|
1349 |
+
parser.add_argument("--output_name", type=str, default=None,
|
1350 |
+
help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名")
|
1351 |
+
parser.add_argument("--save_precision", type=str, default=None,
|
1352 |
+
choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する")
|
1353 |
+
parser.add_argument("--save_every_n_epochs", type=int, default=None,
|
1354 |
+
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
|
1355 |
+
parser.add_argument("--save_n_epoch_ratio", type=int, default=None,
|
1356 |
+
help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)")
|
1357 |
+
parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
|
1358 |
+
parser.add_argument("--save_last_n_epochs_state", type=int, default=None,
|
1359 |
+
help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
|
1360 |
+
parser.add_argument("--save_state", action="store_true",
|
1361 |
+
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
|
1362 |
+
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
|
1363 |
+
|
1364 |
+
parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
|
1365 |
+
parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
|
1366 |
+
help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
|
1367 |
+
parser.add_argument("--use_8bit_adam", action="store_true",
|
1368 |
+
help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
|
1369 |
+
parser.add_argument("--mem_eff_attn", action="store_true",
|
1370 |
+
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
|
1371 |
+
parser.add_argument("--xformers", action="store_true",
|
1372 |
+
help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
|
1373 |
+
parser.add_argument("--vae", type=str, default=None,
|
1374 |
+
help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
|
1375 |
+
|
1376 |
+
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
1377 |
+
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
1378 |
+
parser.add_argument("--max_train_epochs", type=int, default=None,
|
1379 |
+
help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
|
1380 |
+
parser.add_argument("--max_data_loader_n_workers", type=int, default=8,
|
1381 |
+
help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)")
|
1382 |
+
parser.add_argument("--persistent_data_loader_workers", action="store_true",
|
1383 |
+
help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)")
|
1384 |
+
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
|
1385 |
+
parser.add_argument("--gradient_checkpointing", action="store_true",
|
1386 |
+
help="enable gradient checkpointing / grandient checkpointingを有効にする")
|
1387 |
+
parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
|
1388 |
+
help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数")
|
1389 |
+
parser.add_argument("--mixed_precision", type=str, default="no",
|
1390 |
+
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
|
1391 |
+
parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
|
1392 |
+
parser.add_argument("--clip_skip", type=int, default=None,
|
1393 |
+
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
|
1394 |
+
parser.add_argument("--logging_dir", type=str, default=None,
|
1395 |
+
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
|
1396 |
+
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
|
1397 |
+
parser.add_argument("--lr_scheduler", type=str, default="constant",
|
1398 |
+
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
|
1399 |
+
parser.add_argument("--lr_warmup_steps", type=int, default=0,
|
1400 |
+
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
|
1401 |
+
|
1402 |
+
if support_dreambooth:
|
1403 |
+
# DreamBooth training
|
1404 |
+
parser.add_argument("--prior_loss_weight", type=float, default=1.0,
|
1405 |
+
help="loss weight for regularization images / 正則化画像のlossの重み")
|
1406 |
+
|
1407 |
+
|
1408 |
+
def verify_training_args(args: argparse.Namespace):
|
1409 |
+
if args.v_parameterization and not args.v2:
|
1410 |
+
print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
|
1411 |
+
if args.v2 and args.clip_skip is not None:
|
1412 |
+
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
|
1413 |
+
|
1414 |
+
|
1415 |
+
def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool):
|
1416 |
+
# dataset common
|
1417 |
+
parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
|
1418 |
+
parser.add_argument("--shuffle_caption", action="store_true",
|
1419 |
+
help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする")
|
1420 |
+
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
|
1421 |
+
parser.add_argument("--caption_extention", type=str, default=None,
|
1422 |
+
help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
|
1423 |
+
parser.add_argument("--keep_tokens", type=int, default=None,
|
1424 |
+
help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す")
|
1425 |
+
parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
|
1426 |
+
parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
|
1427 |
+
parser.add_argument("--face_crop_aug_range", type=str, default=None,
|
1428 |
+
help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)")
|
1429 |
+
parser.add_argument("--random_crop", action="store_true",
|
1430 |
+
help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)")
|
1431 |
+
parser.add_argument("--debug_dataset", action="store_true",
|
1432 |
+
help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)")
|
1433 |
+
parser.add_argument("--resolution", type=str, default=None,
|
1434 |
+
help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)")
|
1435 |
+
parser.add_argument("--cache_latents", action="store_true",
|
1436 |
+
help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)")
|
1437 |
+
parser.add_argument("--enable_bucket", action="store_true",
|
1438 |
+
help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする")
|
1439 |
+
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
|
1440 |
+
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
|
1441 |
+
parser.add_argument("--bucket_reso_steps", type=int, default=64,
|
1442 |
+
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
|
1443 |
+
parser.add_argument("--bucket_no_upscale", action="store_true",
|
1444 |
+
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
|
1445 |
+
|
1446 |
+
if support_caption_dropout:
|
1447 |
+
# Textual Inversion はcaptionのdropoutをsupportしない
|
1448 |
+
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
|
1449 |
+
parser.add_argument("--caption_dropout_rate", type=float, default=0,
|
1450 |
+
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
|
1451 |
+
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
|
1452 |
+
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
|
1453 |
+
parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
|
1454 |
+
help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
|
1455 |
+
|
1456 |
+
if support_dreambooth:
|
1457 |
+
# DreamBooth dataset
|
1458 |
+
parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")
|
1459 |
+
|
1460 |
+
if support_caption:
|
1461 |
+
# caption dataset
|
1462 |
+
parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル")
|
1463 |
+
parser.add_argument("--dataset_repeats", type=int, default=1,
|
1464 |
+
help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数")
|
1465 |
+
|
1466 |
+
|
1467 |
+
def add_sd_saving_arguments(parser: argparse.ArgumentParser):
|
1468 |
+
parser.add_argument("--save_model_as", type=str, default=None, choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"],
|
1469 |
+
help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)")
|
1470 |
+
parser.add_argument("--use_safetensors", action='store_true',
|
1471 |
+
help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)")
|
1472 |
+
|
1473 |
+
# endregion
|
1474 |
+
|
1475 |
+
# region utils
|
1476 |
+
|
1477 |
+
|
1478 |
+
def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
1479 |
+
# backward compatibility
|
1480 |
+
if args.caption_extention is not None:
|
1481 |
+
args.caption_extension = args.caption_extention
|
1482 |
+
args.caption_extention = None
|
1483 |
+
|
1484 |
+
if args.cache_latents:
|
1485 |
+
assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません"
|
1486 |
+
assert not args.random_crop, "when caching latents, random_crop cannot be used / latentをキャッシュするときはrandom_cropは使えません"
|
1487 |
+
|
1488 |
+
# assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください"
|
1489 |
+
if args.resolution is not None:
|
1490 |
+
args.resolution = tuple([int(r) for r in args.resolution.split(',')])
|
1491 |
+
if len(args.resolution) == 1:
|
1492 |
+
args.resolution = (args.resolution[0], args.resolution[0])
|
1493 |
+
assert len(args.resolution) == 2, \
|
1494 |
+
f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}"
|
1495 |
+
|
1496 |
+
if args.face_crop_aug_range is not None:
|
1497 |
+
args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')])
|
1498 |
+
assert len(args.face_crop_aug_range) == 2 and args.face_crop_aug_range[0] <= args.face_crop_aug_range[1], \
|
1499 |
+
f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}"
|
1500 |
+
else:
|
1501 |
+
args.face_crop_aug_range = None
|
1502 |
+
|
1503 |
+
if support_metadata:
|
1504 |
+
if args.in_json is not None and (args.color_aug or args.random_crop):
|
1505 |
+
print(f"latents in npz is ignored when color_aug or random_crop is True / color_augまたはrandom_cropを有効にした場合、npzファイルのlatentsは無視されます")
|
1506 |
+
|
1507 |
+
|
1508 |
+
def load_tokenizer(args: argparse.Namespace):
|
1509 |
+
print("prepare tokenizer")
|
1510 |
+
if args.v2:
|
1511 |
+
tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
|
1512 |
+
else:
|
1513 |
+
tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
|
1514 |
+
if args.max_token_length is not None:
|
1515 |
+
print(f"update token length: {args.max_token_length}")
|
1516 |
+
return tokenizer
|
1517 |
+
|
1518 |
+
|
1519 |
+
def prepare_accelerator(args: argparse.Namespace):
|
1520 |
+
if args.logging_dir is None:
|
1521 |
+
log_with = None
|
1522 |
+
logging_dir = None
|
1523 |
+
else:
|
1524 |
+
log_with = "tensorboard"
|
1525 |
+
log_prefix = "" if args.log_prefix is None else args.log_prefix
|
1526 |
+
logging_dir = args.logging_dir + "/" + log_prefix + time.strftime('%Y%m%d%H%M%S', time.localtime())
|
1527 |
+
|
1528 |
+
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision,
|
1529 |
+
log_with=log_with, logging_dir=logging_dir)
|
1530 |
+
|
1531 |
+
# accelerateの互換性問題を解決する
|
1532 |
+
accelerator_0_15 = True
|
1533 |
+
try:
|
1534 |
+
accelerator.unwrap_model("dummy", True)
|
1535 |
+
print("Using accelerator 0.15.0 or above.")
|
1536 |
+
except TypeError:
|
1537 |
+
accelerator_0_15 = False
|
1538 |
+
|
1539 |
+
def unwrap_model(model):
|
1540 |
+
if accelerator_0_15:
|
1541 |
+
return accelerator.unwrap_model(model, True)
|
1542 |
+
return accelerator.unwrap_model(model)
|
1543 |
+
|
1544 |
+
return accelerator, unwrap_model
|
1545 |
+
|
1546 |
+
|
1547 |
+
def prepare_dtype(args: argparse.Namespace):
|
1548 |
+
weight_dtype = torch.float32
|
1549 |
+
if args.mixed_precision == "fp16":
|
1550 |
+
weight_dtype = torch.float16
|
1551 |
+
elif args.mixed_precision == "bf16":
|
1552 |
+
weight_dtype = torch.bfloat16
|
1553 |
+
|
1554 |
+
save_dtype = None
|
1555 |
+
if args.save_precision == "fp16":
|
1556 |
+
save_dtype = torch.float16
|
1557 |
+
elif args.save_precision == "bf16":
|
1558 |
+
save_dtype = torch.bfloat16
|
1559 |
+
elif args.save_precision == "float":
|
1560 |
+
save_dtype = torch.float32
|
1561 |
+
|
1562 |
+
return weight_dtype, save_dtype
|
1563 |
+
|
1564 |
+
|
1565 |
+
def load_target_model(args: argparse.Namespace, weight_dtype):
|
1566 |
+
load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) # determine SD or Diffusers
|
1567 |
+
if load_stable_diffusion_format:
|
1568 |
+
print("load StableDiffusion checkpoint")
|
1569 |
+
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
|
1570 |
+
else:
|
1571 |
+
print("load Diffusers pretrained models")
|
1572 |
+
pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
|
1573 |
+
text_encoder = pipe.text_encoder
|
1574 |
+
vae = pipe.vae
|
1575 |
+
unet = pipe.unet
|
1576 |
+
del pipe
|
1577 |
+
|
1578 |
+
# VAEを読み込む
|
1579 |
+
if args.vae is not None:
|
1580 |
+
vae = model_util.load_vae(args.vae, weight_dtype)
|
1581 |
+
print("additional VAE loaded")
|
1582 |
+
|
1583 |
+
return text_encoder, vae, unet, load_stable_diffusion_format
|
1584 |
+
|
1585 |
+
|
1586 |
+
def patch_accelerator_for_fp16_training(accelerator):
|
1587 |
+
org_unscale_grads = accelerator.scaler._unscale_grads_
|
1588 |
+
|
1589 |
+
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
|
1590 |
+
return org_unscale_grads(optimizer, inv_scale, found_inf, True)
|
1591 |
+
|
1592 |
+
accelerator.scaler._unscale_grads_ = _unscale_grads_replacer
|
1593 |
+
|
1594 |
+
|
1595 |
+
def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encoder, weight_dtype=None):
|
1596 |
+
# with no_token_padding, the length is not max length, return result immediately
|
1597 |
+
if input_ids.size()[-1] != tokenizer.model_max_length:
|
1598 |
+
return text_encoder(input_ids)[0]
|
1599 |
+
|
1600 |
+
b_size = input_ids.size()[0]
|
1601 |
+
input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77
|
1602 |
+
|
1603 |
+
if args.clip_skip is None:
|
1604 |
+
encoder_hidden_states = text_encoder(input_ids)[0]
|
1605 |
+
else:
|
1606 |
+
enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True)
|
1607 |
+
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
|
1608 |
+
if weight_dtype is not None:
|
1609 |
+
# this is required for additional network training
|
1610 |
+
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
|
1611 |
+
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
|
1612 |
+
|
1613 |
+
# bs*3, 77, 768 or 1024
|
1614 |
+
encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
|
1615 |
+
|
1616 |
+
if args.max_token_length is not None:
|
1617 |
+
if args.v2:
|
1618 |
+
# v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す 正直この実装でいいのかわからん
|
1619 |
+
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
|
1620 |
+
for i in range(1, args.max_token_length, tokenizer.model_max_length):
|
1621 |
+
chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2] # <BOS> の後から 最後の前まで
|
1622 |
+
if i > 0:
|
1623 |
+
for j in range(len(chunk)):
|
1624 |
+
if input_ids[j, 1] == tokenizer.eos_token: # 空、つまり <BOS> <EOS> <PAD> ...のパターン
|
1625 |
+
chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
|
1626 |
+
states_list.append(chunk) # <BOS> の後から <EOS> の前まで
|
1627 |
+
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
|
1628 |
+
encoder_hidden_states = torch.cat(states_list, dim=1)
|
1629 |
+
else:
|
1630 |
+
# v1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
|
1631 |
+
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
|
1632 |
+
for i in range(1, args.max_token_length, tokenizer.model_max_length):
|
1633 |
+
states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # <BOS> の後から <EOS> の前まで
|
1634 |
+
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
|
1635 |
+
encoder_hidden_states = torch.cat(states_list, dim=1)
|
1636 |
+
|
1637 |
+
return encoder_hidden_states
|
1638 |
+
|
1639 |
+
|
1640 |
+
def get_epoch_ckpt_name(args: argparse.Namespace, use_safetensors, epoch):
|
1641 |
+
model_name = DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
|
1642 |
+
ckpt_name = EPOCH_FILE_NAME.format(model_name, epoch) + (".safetensors" if use_safetensors else ".ckpt")
|
1643 |
+
return model_name, ckpt_name
|
1644 |
+
|
1645 |
+
|
1646 |
+
def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int):
|
1647 |
+
saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs
|
1648 |
+
if saving:
|
1649 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
1650 |
+
save_func()
|
1651 |
+
|
1652 |
+
if args.save_last_n_epochs is not None:
|
1653 |
+
remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs
|
1654 |
+
remove_old_func(remove_epoch_no)
|
1655 |
+
return saving
|
1656 |
+
|
1657 |
+
|
1658 |
+
def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, num_train_epochs: int, global_step: int, text_encoder, unet, vae):
|
1659 |
+
epoch_no = epoch + 1
|
1660 |
+
model_name, ckpt_name = get_epoch_ckpt_name(args, use_safetensors, epoch_no)
|
1661 |
+
|
1662 |
+
if save_stable_diffusion_format:
|
1663 |
+
def save_sd():
|
1664 |
+
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
1665 |
+
print(f"saving checkpoint: {ckpt_file}")
|
1666 |
+
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
|
1667 |
+
src_path, epoch_no, global_step, save_dtype, vae)
|
1668 |
+
|
1669 |
+
def remove_sd(old_epoch_no):
|
1670 |
+
_, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no)
|
1671 |
+
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
1672 |
+
if os.path.exists(old_ckpt_file):
|
1673 |
+
print(f"removing old checkpoint: {old_ckpt_file}")
|
1674 |
+
os.remove(old_ckpt_file)
|
1675 |
+
|
1676 |
+
save_func = save_sd
|
1677 |
+
remove_old_func = remove_sd
|
1678 |
+
else:
|
1679 |
+
def save_du():
|
1680 |
+
out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no))
|
1681 |
+
print(f"saving model: {out_dir}")
|
1682 |
+
os.makedirs(out_dir, exist_ok=True)
|
1683 |
+
model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
|
1684 |
+
src_path, vae=vae, use_safetensors=use_safetensors)
|
1685 |
+
|
1686 |
+
def remove_du(old_epoch_no):
|
1687 |
+
out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no))
|
1688 |
+
if os.path.exists(out_dir_old):
|
1689 |
+
print(f"removing old model: {out_dir_old}")
|
1690 |
+
shutil.rmtree(out_dir_old)
|
1691 |
+
|
1692 |
+
save_func = save_du
|
1693 |
+
remove_old_func = remove_du
|
1694 |
+
|
1695 |
+
saving = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs)
|
1696 |
+
if saving and args.save_state:
|
1697 |
+
save_state_on_epoch_end(args, accelerator, model_name, epoch_no)
|
1698 |
+
|
1699 |
+
|
1700 |
+
def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no):
|
1701 |
+
print("saving state.")
|
1702 |
+
accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)))
|
1703 |
+
|
1704 |
+
last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
|
1705 |
+
if last_n_epochs is not None:
|
1706 |
+
remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs
|
1707 |
+
state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
|
1708 |
+
if os.path.exists(state_dir_old):
|
1709 |
+
print(f"removing old state: {state_dir_old}")
|
1710 |
+
shutil.rmtree(state_dir_old)
|
1711 |
+
|
1712 |
+
|
1713 |
+
def save_sd_model_on_train_end(args: argparse.Namespace, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, global_step: int, text_encoder, unet, vae):
|
1714 |
+
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
1715 |
+
|
1716 |
+
if save_stable_diffusion_format:
|
1717 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
1718 |
+
|
1719 |
+
ckpt_name = model_name + (".safetensors" if use_safetensors else ".ckpt")
|
1720 |
+
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
1721 |
+
|
1722 |
+
print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
|
1723 |
+
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
|
1724 |
+
src_path, epoch, global_step, save_dtype, vae)
|
1725 |
+
else:
|
1726 |
+
out_dir = os.path.join(args.output_dir, model_name)
|
1727 |
+
os.makedirs(out_dir, exist_ok=True)
|
1728 |
+
|
1729 |
+
print(f"save trained model as Diffusers to {out_dir}")
|
1730 |
+
model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
|
1731 |
+
src_path, vae=vae, use_safetensors=use_safetensors)
|
1732 |
+
|
1733 |
+
|
1734 |
+
def save_state_on_train_end(args: argparse.Namespace, accelerator):
|
1735 |
+
print("saving last state.")
|
1736 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
1737 |
+
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
1738 |
+
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
|
1739 |
+
|
1740 |
+
# endregion
|
1741 |
+
|
1742 |
+
# region 前処理用
|
1743 |
+
|
1744 |
+
|
1745 |
+
class ImageLoadingDataset(torch.utils.data.Dataset):
|
1746 |
+
def __init__(self, image_paths):
|
1747 |
+
self.images = image_paths
|
1748 |
+
|
1749 |
+
def __len__(self):
|
1750 |
+
return len(self.images)
|
1751 |
+
|
1752 |
+
def __getitem__(self, idx):
|
1753 |
+
img_path = self.images[idx]
|
1754 |
+
|
1755 |
+
try:
|
1756 |
+
image = Image.open(img_path).convert("RGB")
|
1757 |
+
# convert to tensor temporarily so dataloader will accept it
|
1758 |
+
tensor_pil = transforms.functional.pil_to_tensor(image)
|
1759 |
+
except Exception as e:
|
1760 |
+
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
1761 |
+
return None
|
1762 |
+
|
1763 |
+
return (tensor_pil, img_path)
|
1764 |
+
|
1765 |
+
|
1766 |
+
# endregion
|
Lora/matous_LORA.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:52e4a8e2fbda7b304374826a7f73061d33f4401cab515a92093f1ff09fd5fc19
|
3 |
+
size 151109011
|