import os
from tqdm import tqdm
import argparse
from collections import OrderedDict

parser = argparse.ArgumentParser(description="Extract LoRA from Flex")
parser.add_argument("--base", type=str, default="ostris/Flex.1-alpha", help="Base model path")
parser.add_argument("--tuned", type=str, required=True, help="Tuned model path")
parser.add_argument("--output", type=str, required=True, help="Output path for the extracted LoRA")
parser.add_argument("--rank", type=int, default=32, help="LoRA rank for extraction")
parser.add_argument("--gpu", type=int, default=0, help="GPU index to use for extraction")
parser.add_argument("--full", action="store_true", help="Extract from the entire transformer, not just the transformer blocks")

args = parser.parse_args()
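# Example invocation (script name and paths below are illustrative):
#   python extract_lora_from_flex.py --tuned /path/to/tuned_flex --output loras/extracted.safetensors --rank 32 --gpu 0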
if True:
    # Imports happen after CUDA_VISIBLE_DEVICES is set so that torch only sees the requested GPU.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    import torch
    from safetensors.torch import load_file, save_file
    from lycoris.utils import extract_linear, extract_conv, make_sparse
    from diffusers import FluxTransformer2DModel

base = args.base
tuned = args.tuned
output_path = args.output
dim = args.rank

# Fall back to the current directory when the output path has no directory component.
os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

state_dict_base = {}
state_dict_tuned = {}

output_dict = {}
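# extract_diff compares the base and tuned transformers and returns a LoRA-style
# state dict of lora_A / lora_B factors for every targeted layer whose weights changed.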
@torch.no_grad()
def extract_diff(
    base_unet,
    db_unet,
    mode="fixed",
    linear_mode_param=0,
    conv_mode_param=0,
    extract_device="cpu",
    use_bias=False,
    sparsity=0.98,
    small_conv=False,
):
    UNET_TARGET_REPLACE_MODULE = [
        "Linear",
        "Conv2d",
        "LayerNorm",
        "GroupNorm",
        "GroupNorm32",
        "LoRACompatibleLinear",
        "LoRACompatibleConv",
    ]
    LORA_PREFIX_UNET = "transformer"
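    # make_state_dict pairs each module in the tuned model (target_module) with its
    # counterpart in the base model (root_module) and decomposes the weight difference.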
    def make_state_dict(
        prefix,
        root_module: torch.nn.Module,
        target_module: torch.nn.Module,
        target_replace_modules,
    ):
        loras = {}
        temp = {}

        # Index the base model's extractable modules by name.
        for name, module in root_module.named_modules():
            if module.__class__.__name__ in target_replace_modules:
                temp[name] = module

        # Walk the tuned model and pair each module with its base counterpart.
        for name, module in tqdm(
            list((n, m) for n, m in target_module.named_modules() if n in temp)
        ):
            weights = temp[name]
            lora_name = prefix + "." + name

            layer = module.__class__.__name__
            # Unless --full is given, only the transformer blocks are extracted.
            if "transformer_blocks" not in lora_name and not args.full:
                continue
            if layer in {
                "Linear",
                "Conv2d",
                "LayerNorm",
                "GroupNorm",
                "GroupNorm32",
                "Embedding",
                "LoRACompatibleLinear",
                "LoRACompatibleConv",
            }:
                root_weight = module.weight
                try:
                    # Skip layers whose weights did not change during tuning.
                    if torch.allclose(root_weight, weights.weight):
                        continue
                except Exception:
                    continue
            else:
                continue
            module = module.to(extract_device, torch.float32)
            weights = weights.to(extract_device, torch.float32)
            if mode == "full":
                decompose_mode = "full"
            elif layer == "Linear":
                weight, decompose_mode = extract_linear(
                    (root_weight - weights.weight),
                    mode,
                    linear_mode_param,
                    device=extract_device,
                )
                if decompose_mode == "low rank":
                    extract_a, extract_b, diff = weight
            elif layer == "Conv2d":
                is_linear = root_weight.shape[2] == 1 and root_weight.shape[3] == 1
                weight, decompose_mode = extract_conv(
                    (root_weight - weights.weight),
                    mode,
                    linear_mode_param if is_linear else conv_mode_param,
                    device=extract_device,
                )
                if decompose_mode == "low rank":
                    extract_a, extract_b, diff = weight
                if small_conv and not is_linear and decompose_mode == "low rank":
                    dim = extract_a.size(0)
                    (extract_c, extract_a, _), _ = extract_conv(
                        extract_a.transpose(0, 1),
                        "fixed",
                        dim,
                        extract_device,
                        True,
                    )
                    extract_a = extract_a.transpose(0, 1)
                    extract_c = extract_c.transpose(0, 1)
                    loras[f"{lora_name}.lora_mid.weight"] = (
                        extract_c.detach().cpu().contiguous().half()
                    )
                    diff = (
                        (
                            root_weight
                            - torch.einsum(
                                "i j k l, j r, p i -> p r k l",
                                extract_c,
                                extract_a.flatten(1, -1),
                                extract_b.flatten(1, -1),
                            )
                        )
                        .detach()
                        .cpu()
                        .contiguous()
                    )
                    del extract_c
            else:
                module = module.to("cpu")
                weights = weights.to("cpu")
                continue
            if decompose_mode == "low rank":
                loras[f"{lora_name}.lora_A.weight"] = (
                    extract_a.detach().cpu().contiguous().half()
                )
                loras[f"{lora_name}.lora_B.weight"] = (
                    extract_b.detach().cpu().contiguous().half()
                )

                if use_bias:
                    diff = diff.detach().cpu().reshape(extract_b.size(0), -1)
                    sparse_diff = make_sparse(diff, sparsity).to_sparse().coalesce()

                    indices = sparse_diff.indices().to(torch.int16)
                    values = sparse_diff.values().half()
                    loras[f"{lora_name}.bias_indices"] = indices
                    loras[f"{lora_name}.bias_values"] = values
                    loras[f"{lora_name}.bias_size"] = torch.tensor(diff.shape).to(
                        torch.int16
                    )
                del extract_a, extract_b, diff
            elif decompose_mode == "full":
                if "Norm" in layer:
                    w_key = "w_norm"
                    b_key = "b_norm"
                else:
                    w_key = "diff"
                    b_key = "diff_b"
                weight_diff = module.weight - weights.weight
                loras[f"{lora_name}.{w_key}"] = (
                    weight_diff.detach().cpu().contiguous().half()
                )
                if getattr(weights, "bias", None) is not None:
                    bias_diff = module.bias - weights.bias
                    loras[f"{lora_name}.{b_key}"] = (
                        bias_diff.detach().cpu().contiguous().half()
                    )
            else:
                raise NotImplementedError
            # Move the pair back to CPU in bfloat16 to free GPU memory.
            module = module.to("cpu", torch.bfloat16)
            weights = weights.to("cpu", torch.bfloat16)
        return loras
    all_loras = {}

    all_loras |= make_state_dict(
        LORA_PREFIX_UNET,
        base_unet,
        db_unet,
        UNET_TARGET_REPLACE_MODULE,
    )
    del base_unet, db_unet
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
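    # Report how many distinct modules produced extracted weights.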
    all_lora_name = set()
    for k in all_loras:
        lora_name, weight = k.rsplit(".", 1)
        all_lora_name.add(lora_name)
    print(len(all_lora_name))
    return all_loras
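# Both transformers are loaded on CPU in bfloat16; individual layers are moved
# to the GPU one at a time inside extract_diff.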
print("Loading Base")
base_model = FluxTransformer2DModel.from_pretrained(base, subfolder="transformer", torch_dtype=torch.bfloat16)

print("Loading Tuned")
tuned_model = FluxTransformer2DModel.from_pretrained(tuned, subfolder="transformer", torch_dtype=torch.bfloat16)
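# Extract rank-`dim` LoRA factors from the tuned-minus-base weight differences on the GPU.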
output_dict = extract_diff(
    base_model,
    tuned_model,
    mode="fixed",
    linear_mode_param=dim,
    conv_mode_param=dim,
    extract_device="cuda",
    use_bias=False,
    sparsity=0.98,
    small_conv=False,
)
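# safetensors metadata values must be strings; "format": "pt" marks the tensors as PyTorch.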
meta = OrderedDict()
meta['format'] = 'pt'

save_file(output_dict, output_path, metadata=meta)

print("Done")