# ortha/weight_fusion.py
import argparse
import copy
import itertools
import json
import logging
import os
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from diffusers import DDPMScheduler, DPMSolverMultistepScheduler, StableDiffusionPipeline
from tqdm import tqdm
from mixofshow.models.edlora import revise_edlora_unet_attention_forward
from mixofshow.pipelines.pipeline_edlora import bind_concept_prompt
from mixofshow.utils.util import set_logger
TEMPLATE_SIMPLE = 'photo of a {}'
def chunk_compute_mse(K_target, V_target, W, device, chunk_size=5000):
    '''Compute MSE(F.linear(K_target, W), V_target) in chunks to bound peak memory.'''
num_chunks = (K_target.size(0) + chunk_size - 1) // chunk_size
loss = 0
for i in range(num_chunks):
# Extract the current chunk
start_idx = i * chunk_size
end_idx = min(start_idx + chunk_size, K_target.size(0))
loss += F.mse_loss(
F.linear(K_target[start_idx:end_idx].to(device), W),
V_target[start_idx:end_idx].to(device)) * (end_idx - start_idx)
loss /= K_target.size(0)
return loss
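# A minimal, hedged check (shapes are illustrative, not from this script):
# chunking over the sample dimension reproduces the full-batch MSE, because
# each chunk's mean is re-weighted by its chunk size before normalizing.
#
#   K = torch.randn(12000, 768)
#   V = torch.randn(12000, 320)
#   W = torch.randn(320, 768)
#   full = F.mse_loss(F.linear(K, W), V)
#   assert torch.allclose(chunk_compute_mse(K, V, W, 'cpu'), full, atol=1e-5)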
def update_quasi_newton(K_target, V_target, W, iters, device):
    '''
    Refine W with L-BFGS so that F.linear(K_target, W) matches V_target.
    Args:
        K_target: torch.Tensor, size [n_constraints, n_features]
        V_target: torch.Tensor, size [n_constraints, n_targets]
        W: torch.Tensor, size [n_targets, n_features]
        iters: int, maximum number of L-BFGS iterations
        device: device on which the loss is computed
    Returns:
        Wnew: torch.Tensor, size [n_targets, n_features]
    '''
W = W.detach()
V_target = V_target.detach()
K_target = K_target.detach()
W.requires_grad = True
K_target.requires_grad = False
V_target.requires_grad = False
    best_loss = np.inf  # np.Inf was removed in NumPy 2.0
best_W = None
def closure():
nonlocal best_W, best_loss
optimizer.zero_grad()
if len(W.shape) == 4:
loss = F.mse_loss(F.conv2d(K_target.to(device), W),
V_target.to(device))
else:
loss = chunk_compute_mse(K_target, V_target, W, device)
if loss < best_loss:
best_loss = loss
best_W = W.clone().cpu()
loss.backward()
return loss
optimizer = optim.LBFGS([W],
lr=1,
max_iter=iters,
history_size=25,
line_search_fn='strong_wolfe',
tolerance_grad=1e-16,
tolerance_change=1e-16)
optimizer.step(closure)
    with torch.no_grad():
        # evaluate the best weights on CPU, where best_W was cached
        K_eval = K_target.detach().cpu().to(torch.float32)
        V_eval = V_target.detach().cpu().to(torch.float32)
        if len(W.shape) == 4:
            loss = torch.norm(F.conv2d(K_eval, best_W.to(torch.float32)) - V_eval, 2, dim=1)
        else:
            loss = torch.norm(F.linear(K_eval, best_W.to(torch.float32)) - V_eval, 2, dim=1)
        logging.info('new_concept loss: %e' % loss.mean().item())
return best_W
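# Note: update_quasi_newton is defined but not invoked below; the merge steps
# in this script average weights instead. A hedged standalone sketch on
# synthetic data (all shapes illustrative):
#
#   K_t = torch.randn(2000, 768)                    # constraint features
#   W0 = torch.randn(320, 768)
#   V_t = F.linear(K_t, W0)                         # targets consistent with W0
#   W_fit = update_quasi_newton(K_t, V_t, W0 + 0.01 * torch.randn_like(W0),
#                               iters=50, device='cpu')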
def merge_lora_into_weight(original_state_dict, lora_state_dict, modification_layer_names, model_type, alpha, device):
def get_lora_down_name(original_layer_name):
if model_type == 'text_encoder':
lora_down_name = original_layer_name.replace('q_proj.weight', 'q_proj.lora_down.weight') \
.replace('k_proj.weight', 'k_proj.lora_down.weight') \
.replace('v_proj.weight', 'v_proj.lora_down.weight') \
.replace('out_proj.weight', 'out_proj.lora_down.weight') \
.replace('fc1.weight', 'fc1.lora_down.weight') \
.replace('fc2.weight', 'fc2.lora_down.weight')
else:
            lora_down_name = original_layer_name.replace('to_q.weight', 'to_q.lora_down.weight') \
.replace('to_k.weight', 'to_k.lora_down.weight') \
.replace('to_v.weight', 'to_v.lora_down.weight') \
.replace('to_out.0.weight', 'to_out.0.lora_down.weight') \
.replace('ff.net.0.proj.weight', 'ff.net.0.proj.lora_down.weight') \
.replace('ff.net.2.weight', 'ff.net.2.lora_down.weight') \
.replace('proj_out.weight', 'proj_out.lora_down.weight') \
.replace('proj_in.weight', 'proj_in.lora_down.weight')
return lora_down_name
assert model_type in ['unet', 'text_encoder']
new_state_dict = copy.deepcopy(original_state_dict)
load_cnt = 0
for k in modification_layer_names:
lora_down_name = get_lora_down_name(k)
lora_up_name = lora_down_name.replace('lora_down', 'lora_up')
if lora_up_name in lora_state_dict:
load_cnt += 1
original_params = new_state_dict[k]
lora_down_params = lora_state_dict[lora_down_name].to(device)
lora_up_params = lora_state_dict[lora_up_name].to(device)
if len(original_params.shape) == 4:
lora_param = lora_up_params.squeeze(
) @ lora_down_params.squeeze()
lora_param = lora_param.unsqueeze(-1).unsqueeze(-1)
else:
lora_param = lora_up_params @ lora_down_params
merge_params = original_params + alpha * lora_param
new_state_dict[k] = merge_params
    logging.info(f'loaded {load_cnt} LoRA layers into {model_type}')
return new_state_dict
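# The merge above is the standard LoRA reparameterization
# W' = W + alpha * (lora_up @ lora_down). A hedged toy example with a rank-4
# adapter (shapes illustrative):
#
#   D = torch.randn(4, 768)       # lora_down
#   U = torch.randn(320, 4)       # lora_up
#   W = torch.randn(320, 768)
#   W_merged = W + 0.7 * (U @ D)  # alpha = 0.7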
module_io_recoder = {}
record_feature = False  # global switch: set True so the hooks below record features
def get_hooker(module_name):
def hook(module, feature_in, feature_out):
if module_name not in module_io_recoder:
module_io_recoder[module_name] = {'input': [], 'output': []}
if record_feature:
module_io_recoder[module_name]['input'].append(feature_in[0].cpu())
if module.bias is not None:
if len(feature_out.shape) == 4:
bias = module.bias.unsqueeze(-1).unsqueeze(-1)
else:
bias = module.bias
module_io_recoder[module_name]['output'].append(
(feature_out - bias).cpu()) # remove bias
else:
module_io_recoder[module_name]['output'].append(
feature_out.cpu())
return hook
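# Hedged usage sketch for the hook machinery (the layer and name below are
# hypothetical): register the hook, enable recording, run a forward pass,
# then read the bias-subtracted output back from module_io_recoder.
#
#   layer = torch.nn.Linear(768, 320)
#   handle = layer.register_forward_hook(get_hooker('demo.layer'))
#   record_feature = True
#   _ = layer(torch.randn(2, 768))
#   recorded_out = module_io_recoder['demo.layer']['output'][0]
#   handle.remove()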
def init_stable_diffusion(pretrained_model_path, device):
    # load the pretrained Stable Diffusion (W0) weights and schedulers
model_id = pretrained_model_path
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)
train_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder='scheduler')
test_scheduler = DPMSolverMultistepScheduler.from_pretrained(model_id, subfolder='scheduler')
pipe.safety_checker = None
pipe.scheduler = test_scheduler
return pipe, train_scheduler, test_scheduler
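# Hedged usage sketch (the model id is illustrative):
#   pipe, train_sched, test_sched = init_stable_diffusion(
#       'runwayml/stable-diffusion-v1-5', 'cuda')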
@torch.no_grad()
def get_text_feature(prompts, tokenizer, text_encoder, device, return_type='category_embedding'):
text_features = []
if return_type == 'category_embedding':
for text in prompts:
tokens = tokenizer(
text,
truncation=True,
max_length=tokenizer.model_max_length,
return_length=True,
return_overflowing_tokens=False,
padding='do_not_pad',
).input_ids
new_token_position = torch.where(torch.tensor(tokens) >= 49407)[0]
            # > 49407 excludes the end-of-text token | >= 49407 includes it (49407 is CLIP's <|endoftext|> id)
concept_feature = text_encoder(
torch.LongTensor(tokens).reshape(
1, -1).to(device))[0][:,
new_token_position].reshape(-1, 768)
text_features.append(concept_feature)
return torch.cat(text_features, 0).float()
elif return_type == 'full_embedding':
text_input = tokenizer(prompts,
padding='max_length',
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors='pt')
text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
return text_embeddings
else:
raise NotImplementedError
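# Hedged usage sketch ('category_embedding' keeps only positions whose token
# id is >= 49407, i.e. CLIP's end-of-text token and any newly added concept
# tokens; the prompt below is illustrative):
#
#   feats = get_text_feature(['photo of a <new0>'], tokenizer, text_encoder,
#                            'cuda', return_type='category_embedding')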
def merge_new_concepts_(embedding_list, concept_list, tokenizer, text_encoder):
def add_new_concept(concept_name, embedding):
new_token_names = [
f'<new{start_idx + layer_id}>'
for layer_id in range(NUM_CROSS_ATTENTION_LAYERS)
]
num_added_tokens = tokenizer.add_tokens(new_token_names)
assert num_added_tokens == NUM_CROSS_ATTENTION_LAYERS
new_token_ids = [
tokenizer.convert_tokens_to_ids(token_name)
for token_name in new_token_names
]
text_encoder.resize_token_embeddings(len(tokenizer))
token_embeds = text_encoder.get_input_embeddings().weight.data
        token_embeds[new_token_ids] = embedding[concept_name]
embedding_features.update({concept_name: embedding[concept_name]})
        logging.info(
            f'concept {concept_name} is bound to token ids [{min(new_token_ids)}, {max(new_token_ids)}]'
        )
return start_idx + NUM_CROSS_ATTENTION_LAYERS, new_token_ids, new_token_names
embedding_features = {}
new_concept_cfg = {}
start_idx = 0
NUM_CROSS_ATTENTION_LAYERS = 16
for idx, (embedding,
concept) in enumerate(zip(embedding_list, concept_list)):
concept_names = concept['concept_name'].split(' ')
for concept_name in concept_names:
if not concept_name.startswith('<'):
continue
else:
                assert concept_name in embedding, 'check the config: the provided concept name is not in the lora model'
start_idx, new_token_ids, new_token_names = add_new_concept(
concept_name, embedding)
new_concept_cfg.update({
concept_name: {
'concept_token_ids': new_token_ids,
'concept_token_names': new_token_names
}
})
return embedding_features, new_concept_cfg
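# For reference: each concept reserves NUM_CROSS_ATTENTION_LAYERS (16)
# layer-wise tokens, so the first concept binds '<new0>'..'<new15>', the
# second '<new16>'..'<new31>', and so on, matching EDLoRA's per-layer
# embeddings.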
def parse_new_concepts(concept_cfg):
with open(concept_cfg, 'r') as f:
concept_list = json.load(f)
model_paths = [concept['lora_path'] for concept in concept_list]
embedding_list = []
text_encoder_list = []
unet_crosskv_list = []
unet_spatial_attn_list = []
for model_path in model_paths:
model = torch.load(model_path)['params']
if 'new_concept_embedding' in model and len(
model['new_concept_embedding']) != 0:
embedding_list.append(model['new_concept_embedding'])
else:
embedding_list.append(None)
if 'text_encoder' in model and len(model['text_encoder']) != 0:
text_encoder_list.append(model['text_encoder'])
else:
text_encoder_list.append(None)
if 'unet' in model and len(model['unet']) != 0:
crosskv_matches = ['attn2.to_k.lora', 'attn2.to_v.lora']
crosskv_dict = {
k: v
for k, v in model['unet'].items()
if any([x in k for x in crosskv_matches])
}
if len(crosskv_dict) != 0:
unet_crosskv_list.append(crosskv_dict)
else:
unet_crosskv_list.append(None)
spatial_attn_dict = {
k: v
for k, v in model['unet'].items()
if all([x not in k for x in crosskv_matches])
}
if len(spatial_attn_dict) != 0:
unet_spatial_attn_list.append(spatial_attn_dict)
else:
unet_spatial_attn_list.append(None)
else:
unet_crosskv_list.append(None)
unet_spatial_attn_list.append(None)
return embedding_list, text_encoder_list, unet_crosskv_list, unet_spatial_attn_list, concept_list
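# Hedged sketch of the expected concept_cfg JSON; the field names are
# inferred from the accesses above and in the merge functions, while the
# values are purely illustrative:
#
#   [
#     {
#       "concept_name": "<dog1> <dog2>",
#       "lora_path": "experiments/dog/models/edlora_model-latest.pth",
#       "unet_alpha": 1.0,
#       "text_encoder_alpha": 1.0
#     }
#   ]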
def merge_kv_in_cross_attention(concept_list, optimize_iters, new_concept_cfg,
tokenizer, text_encoder, unet,
unet_crosskv_list, device):
# crosskv attention layer names
matches = ['attn2.to_k', 'attn2.to_v']
cross_attention_idx = -1
cross_kv_layer_names = []
    # the K/V layer names must follow the order down -> mid -> up; record each cross-attention layer's index
for name, _ in unet.down_blocks.named_parameters():
if any([x in name for x in matches]):
if 'to_k' in name:
cross_attention_idx += 1
cross_kv_layer_names.append(
(cross_attention_idx, 'down_blocks.' + name))
cross_kv_layer_names.append(
(cross_attention_idx,
'down_blocks.' + name.replace('to_k', 'to_v')))
else:
pass
for name, _ in unet.mid_block.named_parameters():
if any([x in name for x in matches]):
if 'to_k' in name:
cross_attention_idx += 1
cross_kv_layer_names.append(
(cross_attention_idx, 'mid_block.' + name))
cross_kv_layer_names.append(
(cross_attention_idx,
'mid_block.' + name.replace('to_k', 'to_v')))
else:
pass
for name, _ in unet.up_blocks.named_parameters():
if any([x in name for x in matches]):
if 'to_k' in name:
cross_attention_idx += 1
cross_kv_layer_names.append(
(cross_attention_idx, 'up_blocks.' + name))
cross_kv_layer_names.append(
(cross_attention_idx,
'up_blocks.' + name.replace('to_k', 'to_v')))
else:
pass
    logging.info(
        f'UNet has {len(cross_kv_layer_names)} cross-attention K/V linear layers (fed by text features) to merge'
    )
original_unet_state_dict = unet.state_dict() # original state dict
concept_weights_dict = {}
    # step 1: for each concept, merge its LoRA into the original K/V weights
for concept, tuned_state_dict in zip(concept_list, unet_crosskv_list):
for layer_idx, layer_name in cross_kv_layer_names:
# merge params
original_params = original_unet_state_dict[layer_name]
            # hard-coded: in the UNet, self-/cross-attention projections have no bias term
lora_down_name = layer_name.replace('to_k.weight', 'to_k.lora_down.weight').replace('to_v.weight', 'to_v.lora_down.weight')
lora_up_name = lora_down_name.replace('lora_down', 'lora_up')
alpha = concept['unet_alpha']
lora_down_params = tuned_state_dict[lora_down_name].to(device)
lora_up_params = tuned_state_dict[lora_up_name].to(device)
merge_params = original_params + alpha * lora_up_params @ lora_down_params
if layer_name not in concept_weights_dict:
concept_weights_dict[layer_name] = []
concept_weights_dict[layer_name].append(merge_params)
new_kv_weights = {}
    # step 2: average the merged weights across concepts
    for idx, (layer_idx, layer_name) in enumerate(cross_kv_layer_names):
        Wnew = torch.stack(concept_weights_dict[layer_name])
        Wnew = torch.mean(Wnew, dim=0)
new_kv_weights[layer_name] = Wnew
return new_kv_weights
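# In this variant, the fused K/V weight is the plain average of the
# per-concept LoRA-merged weights:
#   W_fused = mean_i(W0 + alpha_i * U_i @ D_i)
#           = W0 + mean_i(alpha_i * U_i @ D_i),
# i.e. with a shared base W0, averaging the merged weights equals averaging
# the LoRA deltas on top of W0.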
def merge_text_encoder(concept_list, optimize_iters, new_concept_cfg,
tokenizer, text_encoder, text_encoder_list, device):
LoRA_keys = []
for textenc_lora in text_encoder_list:
LoRA_keys += list(textenc_lora.keys())
LoRA_keys = set([
key.replace('.lora_down', '').replace('.lora_up', '')
for key in LoRA_keys
])
text_encoder_layer_names = LoRA_keys
candidate_module_name = [
'q_proj', 'k_proj', 'v_proj', 'out_proj', 'fc1', 'fc2'
]
candidate_module_name = [
name for name in candidate_module_name
if any([name in key for key in LoRA_keys])
]
    logging.info(f'text_encoder has {len(text_encoder_layer_names)} linear layers to merge')
global module_io_recoder, record_feature
hooker_handlers = []
for name, module in text_encoder.named_modules():
if any([item in name for item in candidate_module_name]):
hooker_handlers.append(module.register_forward_hook(hook=get_hooker(name)))
    logging.info(f'registered {len(hooker_handlers)} forward hooks on text_encoder')
original_state_dict = copy.deepcopy(text_encoder.state_dict()) # original state dict
concept_weights_dict = {}
for concept, lora_state_dict in zip(concept_list, text_encoder_list):
merged_state_dict = merge_lora_into_weight(
original_state_dict,
lora_state_dict,
text_encoder_layer_names,
model_type='text_encoder',
alpha=concept['text_encoder_alpha'],
device=device)
text_encoder.load_state_dict(merged_state_dict) # load merged parameters
        # record this concept's merged weights; they are averaged across concepts below
for layer_name in text_encoder_layer_names:
if layer_name not in concept_weights_dict:
concept_weights_dict[layer_name] = []
concept_weights_dict[layer_name].append(merged_state_dict[layer_name])
new_text_encoder_weights = {}
    # average the merged weights across concepts
    for idx, layer_name in enumerate(text_encoder_layer_names):
        Wnew = torch.stack(concept_weights_dict[layer_name])
        Wnew = torch.mean(Wnew, dim=0)
new_text_encoder_weights[layer_name] = Wnew
    logging.info(f'removing {len(hooker_handlers)} forward hooks from text_encoder')
    # remove the forward hooks
for hook_handle in hooker_handlers:
hook_handle.remove()
return new_text_encoder_weights
@torch.no_grad()
def decode_to_latents(concept_prompt, new_concept_cfg, tokenizer, text_encoder,
unet, test_scheduler, num_inference_steps, device,
record_nums, batch_size):
concept_prompt = bind_concept_prompt([concept_prompt], new_concept_cfg)
text_embeddings = get_text_feature(
concept_prompt,
tokenizer,
text_encoder,
device,
return_type='full_embedding').unsqueeze(0)
text_embeddings = text_embeddings.repeat((batch_size, 1, 1, 1))
# sd 1.x
height = 512
width = 512
    latents = torch.randn((batch_size, unet.config.in_channels, height // 8, width // 8))
latents = latents.to(device, dtype=text_embeddings.dtype)
test_scheduler.set_timesteps(num_inference_steps)
latents = latents * test_scheduler.init_noise_sigma
global record_feature
step = (test_scheduler.timesteps.size(0)) // record_nums
record_timestep = test_scheduler.timesteps[torch.arange(0, test_scheduler.timesteps.size(0), step=step)[:record_nums]]
for t in tqdm(test_scheduler.timesteps):
if t in record_timestep:
record_feature = True
else:
record_feature = False
        # no classifier-free guidance here, so the latents are used as-is
        latent_model_input = latents
latent_model_input = test_scheduler.scale_model_input(latent_model_input, t)
noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# compute the previous noisy sample x_t -> x_t-1
latents = test_scheduler.step(noise_pred, t, latents).prev_sample
return latents, text_embeddings
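# Note: decode_to_latents enables feature recording on `record_nums` evenly
# spaced timesteps (e.g. with 20 inference steps and record_nums=4 the hooks
# record at timestep indices 0, 5, 10, 15). It is not invoked by the
# averaging-based merge below and appears to be kept as a utility for
# feature-grounded fusion variants.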
def merge_spatial_attention(concept_list, optimize_iters, new_concept_cfg, tokenizer, text_encoder, unet, unet_spatial_attn_list, test_scheduler, device):
LoRA_keys = []
for unet_lora in unet_spatial_attn_list:
LoRA_keys += list(unet_lora.keys())
LoRA_keys = set([
key.replace('.lora_down', '').replace('.lora_up', '')
for key in LoRA_keys
])
spatial_attention_layer_names = LoRA_keys
candidate_module_name = [
'attn2.to_q', 'attn2.to_out.0', 'attn1.to_q', 'attn1.to_k',
'attn1.to_v', 'attn1.to_out.0', 'ff.net.2', 'ff.net.0.proj',
'proj_out', 'proj_in'
]
candidate_module_name = [
name for name in candidate_module_name
if any([name in key for key in LoRA_keys])
]
    logging.info(
        f'unet has {len(spatial_attention_layer_names)} linear layers to merge'
    )
global module_io_recoder
hooker_handlers = []
for name, module in unet.named_modules():
if any([x in name for x in candidate_module_name]):
hooker_handlers.append(
module.register_forward_hook(hook=get_hooker(name)))
    logging.info(f'registered {len(hooker_handlers)} forward hooks on unet')
original_state_dict = copy.deepcopy(unet.state_dict()) # original state dict
revise_edlora_unet_attention_forward(unet)
concept_weights_dict = {}
for concept, tuned_state_dict in zip(concept_list, unet_spatial_attn_list):
# set unet
module_io_recoder = {} # reinit module io recorder
merged_state_dict = merge_lora_into_weight(
original_state_dict,
tuned_state_dict,
spatial_attention_layer_names,
model_type='unet',
alpha=concept['unet_alpha'],
device=device)
unet.load_state_dict(merged_state_dict) # load merged parameters
        concept_name = concept['concept_name']
        concept_prompt = TEMPLATE_SIMPLE.format(concept_name)  # note: unused below; this variant merges by averaging weights
for layer_name in spatial_attention_layer_names:
if layer_name not in concept_weights_dict:
concept_weights_dict[layer_name] = []
concept_weights_dict[layer_name].append(merged_state_dict[layer_name])
new_spatial_attention_weights = {}
    # average the merged weights across concepts
    for idx, layer_name in enumerate(spatial_attention_layer_names):
        Wnew = torch.stack(concept_weights_dict[layer_name])
        Wnew = torch.mean(Wnew, dim=0)
new_spatial_attention_weights[layer_name] = Wnew
    logging.info(f'removing {len(hooker_handlers)} forward hooks from unet')
for hook_handle in hooker_handlers:
hook_handle.remove()
return new_spatial_attention_weights
def compose_concepts(concept_cfg, optimize_textenc_iters, optimize_unet_iters, pretrained_model_path, save_path, suffix, device):
logging.info('------Step 1: load stable diffusion checkpoint------')
pipe, train_scheduler, test_scheduler = init_stable_diffusion(pretrained_model_path, device)
tokenizer, text_encoder, unet, vae = pipe.tokenizer, pipe.text_encoder, pipe.unet, pipe.vae
for param in itertools.chain(text_encoder.parameters(), unet.parameters(), vae.parameters()):
param.requires_grad = False
logging.info('------Step 2: load new concepts checkpoints------')
embedding_list, text_encoder_list, unet_crosskv_list, unet_spatial_attn_list, concept_list = parse_new_concepts(concept_cfg)
    # step 1: add the new concept tokens to the tokenizer and text encoder embedding layer (in place)
if any([item is not None for item in embedding_list]):
logging.info('------Step 3: merge token embedding------')
_, new_concept_cfg = merge_new_concepts_(embedding_list, concept_list, tokenizer, text_encoder)
else:
_, new_concept_cfg = {}, {}
logging.info('------Step 3: no new embedding, skip merging token embedding------')
# step 2: construct reparameterized text_encoder
if any([item is not None for item in text_encoder_list]):
logging.info('------Step 4: merge text encoder------')
new_text_encoder_weights = merge_text_encoder(
concept_list, optimize_textenc_iters, new_concept_cfg, tokenizer,
text_encoder, text_encoder_list, device)
# update the merged state_dict in text_encoder
text_encoder_state_dict = text_encoder.state_dict()
text_encoder_state_dict.update(new_text_encoder_weights)
text_encoder.load_state_dict(text_encoder_state_dict)
else:
new_text_encoder_weights = {}
logging.info('------Step 4: no new text encoder, skip merging text encoder------')
# step 3: merge unet (k,v in crosskv-attention) params, since they only receive input from text-encoder
if any([item is not None for item in unet_crosskv_list]):
logging.info('------Step 5: merge kv of cross-attention in unet------')
new_kv_weights = merge_kv_in_cross_attention(
concept_list, optimize_textenc_iters, new_concept_cfg,
tokenizer, text_encoder, unet, unet_crosskv_list, device)
# update the merged state_dict in kv of crosskv-attention in Unet
unet_state_dict = unet.state_dict()
unet_state_dict.update(new_kv_weights)
unet.load_state_dict(unet_state_dict)
else:
new_kv_weights = {}
logging.info('------Step 5: no new kv of cross-attention in unet, skip merging kv------')
# step 4: merge unet (q,k,v in self-attention, q in crosskv-attention)
if any([item is not None for item in unet_spatial_attn_list]):
logging.info('------Step 6: merge spatial attention (q in cross-attention, qkv in self-attention) in unet------')
new_spatial_attention_weights = merge_spatial_attention(
concept_list, optimize_unet_iters, new_concept_cfg, tokenizer,
text_encoder, unet, unet_spatial_attn_list, test_scheduler, device)
unet_state_dict = unet.state_dict()
unet_state_dict.update(new_spatial_attention_weights)
unet.load_state_dict(unet_state_dict)
else:
new_spatial_attention_weights = {}
logging.info('------Step 6: no new spatial-attention in unet, skip merging spatial attention------')
checkpoint_save_path = f'{save_path}/combined_model_{suffix}'
pipe.save_pretrained(checkpoint_save_path)
with open(os.path.join(checkpoint_save_path, 'new_concept_cfg.json'), 'w') as json_file:
json.dump(new_concept_cfg, json_file)
def parse_args():
parser = argparse.ArgumentParser('', add_help=False)
parser.add_argument('--concept_cfg', help='json file for multi-concept', required=True, type=str)
parser.add_argument('--save_path', help='folder name to save optimized weights', required=True, type=str)
parser.add_argument('--suffix', help='suffix name', default='base', type=str)
parser.add_argument('--pretrained_models', required=True, type=str)
parser.add_argument('--optimize_unet_iters', default=50, type=int)
parser.add_argument('--optimize_textenc_iters', default=500, type=int)
return parser.parse_args()
if __name__ == '__main__':
args = parse_args()
    # step 1: set up the logger
exp_dir = f'{args.save_path}'
os.makedirs(exp_dir, exist_ok=True)
log_file = f'{exp_dir}/combined_model_{args.suffix}.log'
set_logger(log_file=log_file)
logging.info(args)
compose_concepts(args.concept_cfg,
args.optimize_textenc_iters,
args.optimize_unet_iters,
args.pretrained_models,
args.save_path,
args.suffix,
device='cuda')
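# Example invocation (hedged; paths and model id are illustrative):
#   python weight_fusion.py \
#       --concept_cfg datasets/data_cfgs/two_concepts.json \
#       --save_path experiments/composed_model \
#       --suffix base \
#       --pretrained_models runwayml/stable-diffusion-v1-5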