# Source: toolkit/utils/chop_blocks.py (author: k4d3, commit 12d27fb "renames", 9.48 kB)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# SafeTensorsファイルからブロックを切り出す
# Chop blocks from a SafeTensors file.
import argparse
import logging
import re
from collections import defaultdict
from pathlib import Path
import numpy
from safetensors.numpy import safe_open, save_file
# Module-level logger used by the functions in this script.
logger = logging.getLogger(__name__)
def analyze_lora_layers(
    sft_fd: safe_open,
) -> tuple[list[tuple[tuple[str, str], list[str]]], set[str]]:
    """
    Group the UNet LoRA keys of a SafeTensors file by block.

    Keys named ``lora_unet_<section>_...`` are matched, where ``<section>`` is
    one of the SDXL (``input_blocks``/``middle_block``/``output_blocks``) or
    SD1 (``down_blocks``/``mid_block``/``up_blocks``) section names. Every key
    that does not match is a pass-through key. A warning is logged for any
    pass-through key that does not look like a text-encoder key.

    Args:
        sft_fd (safe_open): An open SafeTensors file.

    Returns:
        A tuple containing:
        - A list of ``((section, idx), keys)`` pairs sorted by
          ``(section, idx)``. ``idx`` is a *string* and ``keys`` is a sorted
          list of the block's keys. This order defines the positions of the
          per-block weights in a vector string.
        - A set of pass-through keys (non-UNet layers).

    Raises:
        ValueError: If no UNet LoRA keys are found.
    """
    RE_LORA_NAME = re.compile(
        r"lora_unet_((?:input|middle|output|down|mid|up)_blocks?)(?:(?:_(\d+))?_attentions)?_(\d+)_.*"
    )
    pass_through_keys: set[str] = set()
    block2keys: dict[tuple[str, str], set[str]] = defaultdict(set)
    for k in sft_fd.keys():
        # NOTE(review): rewrites one specific "_0_1_transformer_blocks_" infix
        # to "_0_" before matching; which key layout this targets is not
        # visible here — confirm.
        m = RE_LORA_NAME.fullmatch(k.replace("_0_1_transformer_blocks_", "_0_"))
        if not m:
            pass_through_keys.add(k)
            continue
        section, idx1, idx2 = m.groups()
        # When both indices are captured (e.g. "..._blocks_<i>_attentions_<j>_...")
        # they are concatenated as strings ("<i><j>"); otherwise idx is the
        # single captured number.
        idx = idx2 if idx1 is None else f"{idx1}{idx2}"
        block2keys[(section, idx)].add(k)

    if not block2keys:
        raise ValueError(
            "No UNet layers found in the LoRA checkpoint (Maybe not a SDXL model?)"
        )

    # idx is a string, so the ordering is lexicographic. Keep it stable:
    # existing vector strings depend on it.
    block2keys_sorted = sorted((k, sorted(v)) for k, v in block2keys.items())

    # Pass-through keys are copied unchanged by the filter, so warn (in a
    # deterministic order) about any that look unexpected.
    for k in sorted(pass_through_keys):
        if "te_" not in k and "text_" not in k:
            logger.warning(
                "key %s passed through but it doesn't look like a text encoder layer",
                k,
            )

    if logger.isEnabledFor(logging.DEBUG):
        dbg = logger.debug

        def print_layers(layers):
            # One debug line per layer, listing its parameter names.
            for layer, params in layers.items():
                params = ", ".join(sorted(params))
                dbg(f" - {layer:<70}: {params}")

        for (section, idx), keys in block2keys_sorted:
            layers = groupby_layer(keys)
            dbg(f"* {section=} {idx=} keys={len(keys)} layers={len(layers)}")
            print_layers(layers)
        dbg(" * Pass through: ")
        print_layers(groupby_layer(pass_through_keys))

    return block2keys_sorted, pass_through_keys
def groupby_layer(
    keys, make_empty=set, update=lambda vs, layer_name, param_name: vs.add(param_name)
):
    """
    Group ``"<layer>.<param>"`` keys by their layer name.

    Each key is split at its LAST ``"."``. A key without any dot is filed
    under the layer name ``""``, with the whole key as its parameter name.

    Args:
        keys: Iterable of key strings.
        make_empty: Zero-argument factory creating a layer's accumulator
            (default: ``set``).
        update: Called as ``update(accumulator, layer, param)`` once per key.
            The default adds ``param`` to the accumulator set.

    Returns:
        A ``collections.defaultdict`` mapping each layer name to its
        accumulator.
    """
    grouped = defaultdict(make_empty)
    for full_key in keys:
        layer_name, _dot, param_name = full_key.rpartition(".")
        update(grouped[layer_name], layer_name, param_name)
    return grouped
def print_block_layout(
    block2keys: list[tuple[tuple[str, str], list[str]]],
    weights: list[float] | None = None,
) -> None:
    """
    Log the layout of the LoRA blocks, optionally with per-block weights.

    Without ``weights``, every block is listed with its key and layer counts.
    The listing is followed by the vector-string format and an example vector
    that drops all blocks. With ``weights``, every block is listed with its
    weight, or "removed" when the weight is (almost) zero.

    Args:
        block2keys: ``((section, idx), keys)`` pairs as returned by
            ``analyze_lora_layers``.
        weights: Optional per-block weights, in the same order as
            ``block2keys``.
    """

    def count_layers(keys) -> int:
        # Only the number of distinct layer names is needed, so skip
        # collecting parameter names.
        return len(groupby_layer(keys, lambda: None, lambda _vs, _layer, _param: None))

    logger.info("Blocks layout:")
    if weights is None:
        for i, ((section, idx), keys) in enumerate(block2keys):
            logger.info(
                f"\t[{i:>2d}] {section:>13}.{idx} keys={len(keys):<3} layers={count_layers(keys):<3}"
            )

        # Short names are keyed on the section prefix.
        # SDXL uses input_blocks/middle_block/output_blocks; SD1 uses
        # down_blocks/mid_block/up_blocks. Keying on the prefix also covers
        # the singular/plural variants ("_blocks?") the block regex accepts.
        prefix2shortname = {
            "input": "INP",
            "down": "INP",
            "middle": "MID",
            "mid": "MID",
            "output": "OUT",
            "up": "OUT",
        }

        def shortname(section: str) -> str:
            return prefix2shortname.get(section.partition("_")[0], section)

        vector_string = ",".join(
            f"{shortname(section)}{idx:>02}" for (section, idx), _ in block2keys
        )
        logger.info(f'Vector string format: "1,{vector_string}"')
        vector_string = ",".join(["0"] * len(block2keys))
        logger.info(f'Example (drops all blocks): "1,{vector_string}"')
    else:
        for i, (((section, idx), keys), weight) in enumerate(zip(block2keys, weights)):
            if abs(weight) > 1e-6:
                if abs(weight - 1) < 1e-6:
                    weight = 1  # display "weight=1" instead of a near-1 float
                w_disp = f"weight={weight}"
            else:
                w_disp = "removed"
            logger.info(
                f"\t[{i:>2d}] {section:>13}.{idx} keys={len(keys):<3} layers={count_layers(keys):<3} {w_disp}"
            )
def filter_blocks(
    sft_fd: safe_open, vector_string: str
) -> dict[str, "numpy.ndarray"] | None:
    """
    Build a filtered state dict according to a block-weight vector string.

    ``vector_string`` has the form ``"<global>,<w0>,<w1>,..."``: a global
    multiplier followed by one weight per block, in the order returned by
    ``analyze_lora_layers``.

    - A block is dropped when its effective weight (``global * w_i``) has
      magnitude below 1e-6.
    - In kept layers that have an ``alpha`` tensor, ``alpha`` is multiplied by
      the effective weight (presumably scaling that layer's strength linearly;
      confirm for non-standard LoRA formats). All other parameters are copied
      unchanged.
    - Layers without ``alpha`` are copied unchanged, with a warning.
    - Pass-through (non-UNet) keys are always copied unchanged.

    Args:
        sft_fd (safe_open): An open SafeTensors file.
        vector_string (str): Comma-separated global weight and block weights.

    Returns:
        The filtered state dict, or None if the number of block weights does
        not match the number of blocks (the expected layout is logged).

    Raises:
        ValueError: If ``vector_string`` contains a non-numeric entry, or if
            the file contains no UNet LoRA layers.
    """
    global_weight, *weights_vector = map(float, vector_string.split(","))
    block2keys, pass_through_keys = analyze_lora_layers(sft_fd)
    if len(weights_vector) != len(block2keys):
        logger.error(f"expected {len(block2keys)} weights, got {len(weights_vector)}")
        print_block_layout(block2keys)
        return None

    # Only build and log the weighted layout when INFO output is enabled
    # (this includes DEBUG verbosity).
    if logger.isEnabledFor(logging.INFO):
        print_block_layout(block2keys, weights_vector)

    state_dict = {}
    for weight, ((s, idx), keys) in zip(weights_vector, block2keys):
        weight *= global_weight
        if abs(weight) < 1e-6:
            logger.debug("reject %s:%s (%s)", s, idx, keys[0])
            continue
        for layer, params in groupby_layer(keys).items():
            logger.debug(
                "accept %s:%s (%s) weight=%.2f params=%s",
                s,
                idx,
                layer,
                weight,
                ",".join(sorted(params)),
            )
            if "alpha" in params:
                params.remove("alpha")
                key = f"{layer}.alpha"
                alpha = sft_fd.get_tensor(key)
                # Cast back to the stored dtype (and to an ndarray): multiplying by
                # a Python float can promote alpha, e.g. a 0-d float32 array
                # becomes a float64 scalar under NumPy 1.x promotion rules.
                state_dict[key] = numpy.asarray(alpha * weight, dtype=alpha.dtype)
                # NOTE: every other parameter (including any dora_scale) is
                # copied unscaled below.
            else:
                logger.warning("no alpha parameter in layer %s: %r", layer, params)
            for param in sorted(params):
                key = f"{layer}.{param}"
                state_dict[key] = sft_fd.get_tensor(key)

    logger.info(
        "Keeping %d keys from the UNet, %d passing through (text encoders)",
        len(state_dict),
        len(pass_through_keys),
    )
    for k in pass_through_keys:
        state_dict[k] = sft_fd.get_tensor(k)
    return state_dict
def setup_logging(verbosity: int) -> None:
    """
    Configure root logging (level and message format) from a verbosity count.

    Uses ``logging.basicConfig``, which has no effect if the root logger
    already has handlers.

    Args:
        verbosity (int): ``2`` or more -> DEBUG, ``1`` -> INFO, ``0`` or
            less -> WARNING.
    """
    if verbosity >= 2:
        level = logging.DEBUG
    elif verbosity == 1:
        level = logging.INFO
    else:
        level = logging.WARNING
    logging.basicConfig(format="%(levelname)s: %(message)s", level=level)
def main() -> None:
    """
    Command-line entry point.

    Without a vector string, logs the block layout of the input LoRA so a
    vector string can be written for it.

    With a vector string, writes a filtered copy of the LoRA to ``--output``
    (default: ``<input stem>-chop`` next to the input). The vector string is
    recorded in the output metadata under ``block_vector_string``.

    Exits with status 1 when the input or vector string is invalid, or when
    filtering leaves nothing to save.
    """
    parser = argparse.ArgumentParser(
        description="Analyze and filter LoRA layers in SafeTensors files."
    )
    parser.add_argument("input_file", type=Path, help="Input SafeTensors file")
    parser.add_argument(
        "vector_string", nargs="?", help="Vector string for filtering blocks"
    )
    parser.add_argument("-o", "--output", type=Path, help="Output file path")
    parser.add_argument(
        "-v",
        "--verbose",
        action="count",
        default=1,
        help="Increase verbosity (can be repeated)",
    )
    parser.add_argument(
        "-q",
        "--quiet",
        action="count",
        default=0,
        help="Suppress all output except errors",
    )
    args = parser.parse_args()
    setup_logging(args.verbose - args.quiet)

    with safe_open(args.input_file, framework="np") as sft_fd:
        if not args.vector_string:
            # Analyze LoRA layers only.
            try:
                block2keys, pass_through_keys = analyze_lora_layers(sft_fd)
            except ValueError as e:
                # No UNet LoRA layers in the file.
                logger.error("%s", e)
                raise SystemExit(1) from e
            print_block_layout(block2keys)
            logger.info(f"Pass through layers: {len(pass_through_keys)}")
            return

        # Filter blocks and save the result.
        try:
            filtered_state_dict = filter_blocks(sft_fd, args.vector_string)
        except ValueError as e:
            # Non-numeric vector entry, or no UNet LoRA layers in the file.
            logger.error("%s", e)
            raise SystemExit(1) from e
        # None: wrong number of weights (already reported by filter_blocks).
        # Empty: every block dropped and nothing to pass through.
        if not filtered_state_dict:
            logger.error("No layers in output!")
            raise SystemExit(1)

        output_path = args.output or args.input_file.with_stem(
            f"{args.input_file.stem}-chop"
        )
        # metadata() returns None for files without a metadata header.
        metadata = dict(sft_fd.metadata() or {})
        metadata["block_vector_string"] = args.vector_string
        save_file(filtered_state_dict, output_path, metadata=metadata)
        logger.info(f"Filtered LoRA saved to {output_path}")
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()