only scale the alphas, dora only for boolean weight
Signed-off-by: Balazs Horvath <acsipont@gmail.com>
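
The title describes how per-block weights from the vector string are applied when a checkpoint is chopped: only each layer's `.alpha` tensor is multiplied by the block weight, which is already enough to scale that layer's whole contribution. A minimal sketch of that identity, assuming a standard LoRA layer where the applied update is `(alpha / rank) * up @ down` (the names below are illustrative assumptions, not identifiers from chop_blocks):

import numpy as np

# Illustrative LoRA layer: delta_W = (alpha / rank) * up @ down.
# The names up, down, alpha, rank, block_weight are assumptions for this
# sketch, not identifiers from chop_blocks.
rank = 8
down = np.random.randn(rank, 64).astype(np.float32)   # lora_down.weight
up = np.random.randn(320, rank).astype(np.float32)    # lora_up.weight
alpha = np.float32(4.0)
block_weight = 0.5  # one entry from the vector string

delta = (alpha / rank) * (up @ down)

# Multiplying only alpha by the block weight scales the whole contribution:
delta_from_scaled_alpha = ((alpha * block_weight) / rank) * (up @ down)
assert np.allclose(delta_from_scaled_alpha, block_weight * delta)

Because the lora_up/lora_down matrices are copied unchanged, a fractional block weight only has to touch the small alpha tensor.
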
chop_blocks  CHANGED  (+83 -17)
@@ -6,6 +6,7 @@ import logging
 import re
 from collections import defaultdict
 from pathlib import Path
+import numpy as np
 
 from safetensors.numpy import safe_open, save_file
 
@@ -34,7 +35,7 @@ def analyze_lora_layers(
     block2keys: dict[tuple[str, int], set[str]] = defaultdict(set)
 
     for k in sft_fd.keys():
-        m = RE_LORA_NAME.fullmatch(k)
+        m = RE_LORA_NAME.fullmatch(k.replace("_0_1_transformer_blocks_", "_0_"))
         if not m:
             pass_through_keys.add(k)
             continue
@@ -50,10 +51,42 @@ def analyze_lora_layers(
         raise ValueError(
             "No UNet layers found in the LoRA checkpoint (Maybe not a SDXL model?)"
         )
-    block2keys_sorted = sorted(block2keys.items())
+    block2keys_sorted = sorted((k, sorted(v)) for k, v in block2keys.items())
+
+    for k in pass_through_keys:
+        if not "te_" in k and "text_" not in k:
+            logging.warning(
+                f"key {k} removed but it doesn't look like a text encoder layer"
+            )
+
+    def print_layers(layers):
+        for layer, params in layers.items():
+            params = ", ".join(sorted(params))
+            dbg(f" - {layer:<70}: {params}")
+
+    if logger.getEffectiveLevel() <= logging.DEBUG:
+        dbg = logger.debug
+        for (section, idx), keys in block2keys_sorted:
+            layers = groupby_layer(keys)
+            dbg(f"* {section=} {idx=} keys={len(keys)} layers={len(layers)}")
+            print_layers(layers)
+
+        logger.debug(f" * Pass through: ")
+        print_layers(groupby_layer(pass_through_keys))
     return block2keys_sorted, pass_through_keys
 
 
+def groupby_layer(
+    keys, make_empty=set, update=lambda vs, layer_name, param_name: vs.add(param_name)
+):
+    d = defaultdict(make_empty)
+    for k in keys:
+        layer, _, param = k.rpartition(".")
+        vs = d[layer]
+        update(vs, layer, param)
+    return d
+
+
 def print_block_layout(
     block2keys: list[tuple[tuple[str, int], set[str]]],
     weights: list[float] | None = None,
@@ -67,8 +100,8 @@ def print_block_layout(
     """
     logger.info("Blocks layout:")
     if weights is None:
-        for i, ((section, idx), v) in enumerate(block2keys):
-            logger.info(f"\t[{i:>2d}] {section:>13}.{idx} layers={len(v):<3}")
+        for i, ((section, idx), keys) in enumerate(block2keys):
+            logger.info(f"\t[{i:>2d}] {section:>13}.{idx} layers={len(keys):<3}")
         section2shortname = {
             # SDXL names:
             "input_blocks": "INP",
@@ -86,17 +119,20 @@ def print_block_layout(
         vector_string = ",".join("0" * len(block2keys))
         logger.info(f'Example (drops all blocks): "1,{vector_string}"')
     else:
-        for i, (((section, idx), v), weight) in enumerate(zip(block2keys, weights)):
+        for i, (((section, idx), keys), weight) in enumerate(zip(block2keys, weights)):
             if abs(weight) > 1e-6:
                 if abs(weight - 1) < 1e-6:
                     weight = 1
-                logger.info(
-                    f"\t[{i:>2d}] {section:>13}.{idx} layers={len(v):<3} weight={weight}"
-                )
+                w_disp = f"weight={weight}"
             else:
-
-
-
+                w_disp = "removed"
+
+            layers = len(
+                groupby_layer(keys, lambda: None, lambda _layers, _layer, _attr: None)
+            )
+            logger.info(
+                f"\t[{i:>2d}] {section:>13}.{idx} keys={len(keys):<3} layers={layers:<3} {w_disp}"
+            )
 
 
 def filter_blocks(sft_fd: safe_open, vector_string: str) -> dict[str, "numpy.ndarray"]:
@@ -122,15 +158,44 @@ def filter_blocks(sft_fd: safe_open, vector_string: str) -> dict[str, "numpy.ndarray"]:
     print_block_layout(block2keys, weights_vector)
 
     state_dict = {}
-    for weight, (
+    for weight, ((s, idx), keys) in zip(weights_vector, block2keys):
         weight *= global_weight
         if abs(weight) < 1e-6:
+            logger.debug("reject %s:%s (%s)", s, idx, keys[0])
             continue
-
-
-
-
-
+
+        for layer, params in groupby_layer(keys).items():
+            logger.debug(
+                "accept %s:%s (%s) weight=%.2f params=%s",
+                s,
+                idx,
+                layer,
+                weight,
+                ",".join(params),
+            )
+
+            if "alpha" in params:
+                params.remove("alpha")
+                key = f"{layer}.alpha"
+                state_dict[key] = sft_fd.get_tensor(key) * weight
+                # if 'dora_scale' in params:
+                #     params.remove("dora_scale")
+                #     key = f"{layer}.dora_scale"
+                #     tensor = sft_fd.get_tensor(key)
+                #     if abs(weight - 1.0) > 1e-6:
+                #         tensor -= 1.0
+                #         tensor *= weight
+                #         tensor += 1.0
+                #     state_dict[key] = tensor
+
+                for param in params:
+                    key = f"{layer}.{param}"
+                    state_dict[key] = sft_fd.get_tensor(key)
+            else:
+                logging.warning("no alpha parameter in layer %s: %r", layer, params)
+                for param in params:
+                    key = f"{layer}.{param}"
+                    state_dict[key] = sft_fd.get_tensor(key)
 
     logger.info(
         "Keeping %d keys from the UNet, %d passing through (text encoders)",
@@ -190,6 +255,7 @@ def main() -> None:
     # Filter blocks and save the result
     filtered_state_dict = filter_blocks(sft_fd, args.vector_string)
     if filtered_state_dict is None:
+        logging.error("No layers in output!")
         exit(1)
 
     # Determine output path
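
DoRA checkpoints are the exception the title alludes to: dora_scale stores a magnitude that multiplies the normalized merged weight, so multiplying it by a block weight would not scale the update the way scaling alpha does. The commented-out branch in filter_blocks hints at blending it toward the neutral value 1.0 instead; since that branch stays disabled, fractional weights effectively apply only to .alpha and DoRA tensors pass through unchanged, which is only consistent for boolean (0 or 1) block weights. A sketch of that disabled blend, as a guess at its intent rather than behavior of the committed script:

import numpy as np

def blend_dora_scale(dora_scale: np.ndarray, block_weight: float) -> np.ndarray:
    """Blend dora_scale toward its neutral value 1.0 by block_weight.

    Mirrors the commented-out branch in filter_blocks; the committed script
    does not perform this, so it is only a guess at the intent.
    """
    if abs(block_weight - 1.0) <= 1e-6:
        return dora_scale  # a weight of 1 leaves the tensor untouched
    return (dora_scale - 1.0) * block_weight + 1.0

For example, blend_dora_scale(tensor, 0.5) would move the stored magnitudes halfway back toward 1.0.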