k4d3 committed
Commit 2bb76e3 · Parent: cd52e87

only scale the alphas, dora only for boolean weight

Signed-off-by: Balazs Horvath <acsipont@gmail.com>
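
Context for the commit title: in a standard LoRA layer the effective weight update is ΔW = (alpha / rank) · (lora_up @ lora_down), so multiplying only the stored alpha by the per-block weight rescales the whole block contribution without touching the up/down matrices. The DoRA magnitude (dora_scale) does not factor out the same way, which is presumably why its interpolation toward 1 stays commented out below and, as the title suggests, DoRA tensors are only handled correctly for a block weight of exactly 0 or 1. A minimal numpy sketch of the alpha equivalence (array names and sizes are made up, not part of the script):

import numpy as np

rank, d_out, d_in = 4, 16, 16
down = np.random.randn(rank, d_in).astype(np.float32)  # lora_down.weight
up = np.random.randn(d_out, rank).astype(np.float32)   # lora_up.weight
alpha, block_weight = 4.0, 0.5

delta = (alpha / rank) * (up @ down)  # original block contribution
delta_alpha_scaled = ((alpha * block_weight) / rank) * (up @ down)

# Scaling only alpha is the same as scaling the whole block contribution.
assert np.allclose(delta_alpha_scaled, block_weight * delta)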

Files changed (1)
  1. chop_blocks +83 -17
chop_blocks CHANGED
@@ -6,6 +6,7 @@ import logging
 import re
 from collections import defaultdict
 from pathlib import Path
+import numpy as np
 
 from safetensors.numpy import safe_open, save_file
 
@@ -34,7 +35,7 @@ def analyze_lora_layers(
     block2keys: dict[tuple[str, int], set[str]] = defaultdict(set)
 
     for k in sft_fd.keys():
-        m = RE_LORA_NAME.fullmatch(k)
+        m = RE_LORA_NAME.fullmatch(k.replace("_0_1_transformer_blocks_", "_0_"))
         if not m:
             pass_through_keys.add(k)
             continue
@@ -50,10 +51,42 @@ def analyze_lora_layers(
         raise ValueError(
             "No UNet layers found in the LoRA checkpoint (Maybe not a SDXL model?)"
         )
-    block2keys_sorted = sorted(block2keys.items())
+    block2keys_sorted = sorted((k, sorted(v)) for k, v in block2keys.items())
+
+    for k in pass_through_keys:
+        if "te_" not in k and "text_" not in k:
+            logging.warning(
+                f"key {k} removed but it doesn't look like a text encoder layer"
+            )
+
+    def print_layers(layers):
+        for layer, params in layers.items():
+            params = ", ".join(sorted(params))
+            dbg(f" - {layer:<70}: {params}")
+
+    if logger.getEffectiveLevel() <= logging.DEBUG:
+        dbg = logger.debug
+        for (section, idx), keys in block2keys_sorted:
+            layers = groupby_layer(keys)
+            dbg(f"* {section=} {idx=} keys={len(keys)} layers={len(layers)}")
+            print_layers(layers)
+
+        logger.debug(" * Pass through: ")
+        print_layers(groupby_layer(pass_through_keys))
     return block2keys_sorted, pass_through_keys
 
 
+def groupby_layer(
+    keys, make_empty=set, update=lambda vs, layer_name, param_name: vs.add(param_name)
+):
+    d = defaultdict(make_empty)
+    for k in keys:
+        layer, _, param = k.rpartition(".")
+        vs = d[layer]
+        update(vs, layer, param)
+    return d
+
+
 def print_block_layout(
     block2keys: list[tuple[tuple[str, int], set[str]]],
     weights: list[float] | None = None,
@@ -67,8 +100,8 @@ def print_block_layout(
     """
     logger.info("Blocks layout:")
     if weights is None:
-        for i, ((section, idx), v) in enumerate(block2keys):
-            logger.info(f"\t[{i:>2d}] {section:>13}.{idx} layers={len(v):<3}")
+        for i, ((section, idx), keys) in enumerate(block2keys):
+            logger.info(f"\t[{i:>2d}] {section:>13}.{idx} layers={len(keys):<3}")
         section2shortname = {
             # SDXL names:
             "input_blocks": "INP",
@@ -86,17 +119,20 @@ def print_block_layout(
         vector_string = ",".join("0" * len(block2keys))
         logger.info(f'Example (drops all blocks): "1,{vector_string}"')
     else:
-        for i, (((section, idx), v), weight) in enumerate(zip(block2keys, weights)):
+        for i, (((section, idx), keys), weight) in enumerate(zip(block2keys, weights)):
             if abs(weight) > 1e-6:
                 if abs(weight - 1) < 1e-6:
                     weight = 1
-                logger.info(
-                    f"\t[{i:>2d}] {section:>13}.{idx} layers={len(v):<3} weight={weight}"
-                )
+                w_disp = f"weight={weight}"
             else:
-                logger.info(
-                    f"\t[{i:>2d}] {section:>13}.{idx} layers={len(v):<3} (removed)"
-                )
+                w_disp = "removed"
+
+            layers = len(
+                groupby_layer(keys, lambda: None, lambda _layers, _layer, _attr: None)
+            )
+            logger.info(
+                f"\t[{i:>2d}] {section:>13}.{idx} keys={len(keys):<3} layers={layers:<3} {w_disp}"
+            )
 
 
 def filter_blocks(sft_fd: safe_open, vector_string: str) -> dict[str, "numpy.ndarray"]:
@@ -122,15 +158,44 @@ def filter_blocks(sft_fd: safe_open, vector_string: str) -> dict[str, "numpy.nda
     print_block_layout(block2keys, weights_vector)
 
     state_dict = {}
-    for weight, (_, keys) in zip(weights_vector, block2keys):
+    for weight, ((s, idx), keys) in zip(weights_vector, block2keys):
         weight *= global_weight
         if abs(weight) < 1e-6:
+            logger.debug("reject %s:%s (%s)", s, idx, keys[0])
             continue
-        for k in keys:
-            tensor = sft_fd.get_tensor(k)
-            if abs(weight - 1.0) > 1e-6:
-                tensor *= weight
-            state_dict[k] = tensor
+
+        for layer, params in groupby_layer(keys).items():
+            logger.debug(
+                "accept %s:%s (%s) weight=%.2f params=%s",
+                s,
+                idx,
+                layer,
+                weight,
+                ",".join(params),
+            )
+
+            if "alpha" in params:
+                params.remove("alpha")
+                key = f"{layer}.alpha"
+                state_dict[key] = sft_fd.get_tensor(key) * weight
+                # if 'dora_scale' in params:
+                #     params.remove("dora_scale")
+                #     key = f"{layer}.dora_scale"
+                #     tensor = sft_fd.get_tensor(key)
+                #     if abs(weight - 1.0) > 1e-6:
+                #         tensor -= 1.0
+                #         tensor *= weight
+                #         tensor += 1.0
+                #     state_dict[key] = tensor
+
+                for param in params:
+                    key = f"{layer}.{param}"
+                    state_dict[key] = sft_fd.get_tensor(key)
+            else:
+                logging.warning("no alpha parameter in layer %s: %r", layer, params)
+                for param in params:
+                    key = f"{layer}.{param}"
+                    state_dict[key] = sft_fd.get_tensor(key)
 
     logger.info(
         "Keeping %d keys from the UNet, %d passing through (text encoders)",
@@ -190,6 +255,7 @@ def main() -> None:
     # Filter blocks and save the result
     filtered_state_dict = filter_blocks(sft_fd, args.vector_string)
     if filtered_state_dict is None:
+        logging.error("No layers in output!")
        exit(1)
 
     # Determine output path
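
As a quick sanity check on the new behavior, one can compare alphas between the input LoRA and the file chop_blocks writes. A small sketch, with placeholder file names, using the top-level safetensors.safe_open API rather than the script's safetensors.numpy import:

from safetensors import safe_open

# Placeholder paths; point these at the input LoRA and the chopped output.
with safe_open("lora_in.safetensors", framework="numpy") as src, safe_open(
    "lora_out.safetensors", framework="numpy"
) as dst:
    for k in dst.keys():
        if k.endswith(".alpha"):
            # Every kept key also exists in the source (the script only copies keys),
            # so the ratio recovers the weight that was applied; assumes non-zero alphas.
            ratio = float(dst.get_tensor(k)) / float(src.get_tensor(k))
            print(f"{k}: alpha scaled by {ratio:g}")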