ujin-song committed on
Commit
b7e867a
1 Parent(s): 2fe52f3

Upload .py files to the root directory

regionally_controlable_sampling.py ADDED
@@ -0,0 +1,189 @@
+import argparse
+import hashlib
+import json
+import os.path
+
+import torch
+from diffusers import DPMSolverMultistepScheduler
+from diffusers.models import T2IAdapter
+from PIL import Image
+
+from mixofshow.pipelines.pipeline_regionally_t2iadapter import RegionallyT2IAdapterPipeline
+
+
+def sample_image(pipe,
+                 input_prompt,
+                 input_neg_prompt=None,
+                 generator=None,
+                 num_inference_steps=50,
+                 guidance_scale=7.5,
+                 sketch_adaptor_weight=1.0,
+                 region_sketch_adaptor_weight='',
+                 keypose_adaptor_weight=1.0,
+                 region_keypose_adaptor_weight='',
+                 **extra_kargs
+                 ):
+
+    keypose_condition = extra_kargs.pop('keypose_condition')
+    if keypose_condition is not None:
+        keypose_adapter_input = [keypose_condition] * len(input_prompt)
+    else:
+        keypose_adapter_input = None
+
+    sketch_condition = extra_kargs.pop('sketch_condition')
+    if sketch_condition is not None:
+        sketch_adapter_input = [sketch_condition] * len(input_prompt)
+    else:
+        sketch_adapter_input = None
+
+    images = pipe(
+        prompt=input_prompt,
+        negative_prompt=input_neg_prompt,
+        keypose_adapter_input=keypose_adapter_input,
+        keypose_adaptor_weight=keypose_adaptor_weight,
+        region_keypose_adaptor_weight=region_keypose_adaptor_weight,
+        sketch_adapter_input=sketch_adapter_input,
+        sketch_adaptor_weight=sketch_adaptor_weight,
+        region_sketch_adaptor_weight=region_sketch_adaptor_weight,
+        generator=generator,
+        guidance_scale=guidance_scale,
+        num_inference_steps=num_inference_steps,
+        **extra_kargs).images
+    return images
+
+
+def build_model(pretrained_model, device):
+    pipe = RegionallyT2IAdapterPipeline.from_pretrained(pretrained_model, torch_dtype=torch.float16).to(device)
+    assert os.path.exists(os.path.join(pretrained_model, 'new_concept_cfg.json'))
+    with open(os.path.join(pretrained_model, 'new_concept_cfg.json'), 'r') as json_file:
+        new_concept_cfg = json.load(json_file)
+    pipe.set_new_concept_cfg(new_concept_cfg)
+    pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(pretrained_model, subfolder='scheduler')
+    pipe.keypose_adapter = T2IAdapter.from_pretrained('TencentARC/t2iadapter_openpose_sd14v1', torch_dtype=torch.float16).to(device)
+    pipe.sketch_adapter = T2IAdapter.from_pretrained('TencentARC/t2iadapter_sketch_sd14v1', torch_dtype=torch.float16).to(device)
+    return pipe
+
+
+def prepare_text(prompt, region_prompts, height, width):
+    '''
+    Args:
+        prompt_entity: [subject1]-*-[attribute1]-*-[Location1]|[subject2]-*-[attribute2]-*-[Location2]|[global text]
+    Returns:
+        full_prompt: subject1, attribute1 and subject2, attribute2, global text
+        context_prompt: subject1 and subject2, global text
+        entity_collection: [(subject1, attribute1), Location1]
+    '''
+    region_collection = []
+
+    regions = region_prompts.split('|')
+
+    for region in regions:
+        if region == '':
+            break
+        prompt_region, neg_prompt_region, pos = region.split('-*-')
+        prompt_region = prompt_region.replace('[', '').replace(']', '')
+        neg_prompt_region = neg_prompt_region.replace('[', '').replace(']', '')
+        pos = eval(pos)
+        if len(pos) == 0:
+            pos = [0, 0, 1, 1]
+        else:
+            pos[0], pos[2] = pos[0] / height, pos[2] / height
+            pos[1], pos[3] = pos[1] / width, pos[3] / width
+
+        region_collection.append((prompt_region, neg_prompt_region, pos))
+    return (prompt, region_collection)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser('', add_help=False)
+    parser.add_argument('--pretrained_model', default='experiments/composed_edlora/anythingv4/hina+kario+tezuka+mitsuha+son_anythingv4/combined_model_base', type=str)
+    parser.add_argument('--sketch_condition', default=None, type=str)
+    parser.add_argument('--sketch_adaptor_weight', default=1.0, type=float)
+    parser.add_argument('--region_sketch_adaptor_weight', default='', type=str)
+    parser.add_argument('--keypose_condition', default=None, type=str)
+    parser.add_argument('--keypose_adaptor_weight', default=1.0, type=float)
+    parser.add_argument('--region_keypose_adaptor_weight', default='', type=str)
+    parser.add_argument('--save_dir', default=None, type=str)
+    parser.add_argument('--prompt', default='photo of a toy', type=str)
+    parser.add_argument('--negative_prompt', default='', type=str)
+    parser.add_argument('--prompt_rewrite', default='', type=str)
+    parser.add_argument('--seed', default=16141, type=int)
+    parser.add_argument('--suffix', default='', type=str)
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+    pipe = build_model(args.pretrained_model, device)
+
+    if args.sketch_condition is not None and os.path.exists(args.sketch_condition):
+        sketch_condition = Image.open(args.sketch_condition).convert('L')
+        width_sketch, height_sketch = sketch_condition.size
+        print('use sketch condition')
+    else:
+        sketch_condition, width_sketch, height_sketch = None, 0, 0
+        print('skip sketch condition')
+
+    if args.keypose_condition is not None and os.path.exists(args.keypose_condition):
+        keypose_condition = Image.open(args.keypose_condition).convert('RGB')
+        width_pose, height_pose = keypose_condition.size
+        print('use pose condition')
+    else:
+        keypose_condition, width_pose, height_pose = None, 0, 0
+        print('skip pose condition')
+
+    if width_sketch != 0 and width_pose != 0:
+        assert width_sketch == width_pose and height_sketch == height_pose, 'conditions should be same size'
+    width, height = max(width_pose, width_sketch), max(height_pose, height_sketch)
+
+    kwargs = {
+        'sketch_condition': sketch_condition,
+        'keypose_condition': keypose_condition,
+        'height': height,
+        'width': width,
+    }
+
+    prompts = [args.prompt]
+    prompts_rewrite = [args.prompt_rewrite]
+    input_prompt = [prepare_text(p, p_w, height, width) for p, p_w in zip(prompts, prompts_rewrite)]
+    save_prompt = input_prompt[0][0]
+
+    image = sample_image(
+        pipe,
+        input_prompt=input_prompt,
+        input_neg_prompt=[args.negative_prompt] * len(input_prompt),
+        generator=torch.Generator(device).manual_seed(args.seed),
+        sketch_adaptor_weight=args.sketch_adaptor_weight,
+        region_sketch_adaptor_weight=args.region_sketch_adaptor_weight,
+        keypose_adaptor_weight=args.keypose_adaptor_weight,
+        region_keypose_adaptor_weight=args.region_keypose_adaptor_weight,
+        **kwargs)
+
+    print(f'save to: {args.save_dir}')
+
+    configs = [
+        f'pretrained_model: {args.pretrained_model}\n',
+        f'context_prompt: {args.prompt}\n', f'neg_context_prompt: {args.negative_prompt}\n',
+        f'sketch_condition: {args.sketch_condition}\n', f'sketch_adaptor_weight: {args.sketch_adaptor_weight}\n',
+        f'region_sketch_adaptor_weight: {args.region_sketch_adaptor_weight}\n',
+        f'keypose_condition: {args.keypose_condition}\n', f'keypose_adaptor_weight: {args.keypose_adaptor_weight}\n',
+        f'region_keypose_adaptor_weight: {args.region_keypose_adaptor_weight}\n', f'random seed: {args.seed}\n',
+        f'prompt_rewrite: {args.prompt_rewrite}\n'
+    ]
+    hash_code = hashlib.sha256(''.join(configs).encode('utf-8')).hexdigest()[:8]
+
+    save_prompt = save_prompt.replace(' ', '_')
+    # save_name = f'{save_prompt}---{args.suffix}---{hash_code}.png'
+    # save_dir = os.path.join(args.save_dir, f'seed_{args.seed}')
+    save_name = f'{save_prompt}---{args.suffix}(seed{args.seed})---{hash_code}.png'
+    save_dir = args.save_dir
+    save_path = os.path.join(save_dir, save_name)
+    save_config_path = os.path.join(save_dir, save_name.replace('.png', '.txt'))
+
+    os.makedirs(save_dir, exist_ok=True)
+    image[0].save(os.path.join(save_dir, save_name))
+
+    with open(save_config_path, 'w') as fw:
+        fw.writelines(configs)
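For context, a minimal invocation sketch for the script above, built only from the flags defined in its parse_args(); the model path, concept tokens, region boxes, and output directory are illustrative placeholders, not values from this commit. Each region in --prompt_rewrite follows the [region prompt]-*-[region negative prompt]-*-[top, left, bottom, right] format parsed by prepare_text(), with the box given in pixels of the condition image and regions separated by '|':

    python regionally_controlable_sampling.py \
        --pretrained_model=experiments/composed_edlora/anythingv4/combined_model_base \
        --keypose_condition=pose.png \
        --keypose_adaptor_weight=1.0 \
        --prompt='two people standing in a garden, 4K, high quality' \
        --prompt_rewrite='[a <concept1_1> <concept1_2>, standing in a garden]-*-[low quality]-*-[0, 0, 1024, 512]|[a <concept2_1> <concept2_2>, standing in a garden]-*-[low quality]-*-[0, 512, 1024, 1024]' \
        --seed=16141 \
        --save_dir=results/regional_sample

Note that at least one of --sketch_condition or --keypose_condition must point to an existing image, since the target height and width are taken from the condition image, and --save_dir must be set for the output image and config to be written.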
test_edlora.py ADDED
@@ -0,0 +1,110 @@
+import argparse
+import os
+import os.path as osp
+
+import torch
+import torch.utils.checkpoint
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import set_seed
+from diffusers import DPMSolverMultistepScheduler
+from diffusers.utils import check_min_version
+from omegaconf import OmegaConf
+from tqdm import tqdm
+
+from mixofshow.data.prompt_dataset import PromptDataset
+from mixofshow.pipelines.pipeline_edlora import EDLoRAPipeline, StableDiffusionPipeline
+from mixofshow.utils.convert_edlora_to_diffusers import convert_edlora
+from mixofshow.utils.util import NEGATIVE_PROMPT, compose_visualize, dict2str, pil_imwrite, set_path_logger
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version('0.18.2')
+
+
+def visual_validation(accelerator, pipe, dataloader, current_iter, opt):
+    dataset_name = dataloader.dataset.opt['name']
+    pipe.unet.eval()
+    pipe.text_encoder.eval()
+
+    for idx, val_data in enumerate(tqdm(dataloader)):
+        output = pipe(
+            prompt=val_data['prompts'],
+            latents=val_data['latents'].to(dtype=torch.float16),
+            negative_prompt=[NEGATIVE_PROMPT] * len(val_data['prompts']),
+            num_inference_steps=opt['val']['sample'].get('num_inference_steps', 50),
+            guidance_scale=opt['val']['sample'].get('guidance_scale', 7.5),
+        ).images
+
+        for img, prompt, indice in zip(output, val_data['prompts'], val_data['indices']):
+            img_name = '{prompt}---G_{guidance_scale}_S_{steps}---{indice}'.format(
+                prompt=prompt.replace(' ', '_'),
+                guidance_scale=opt['val']['sample'].get('guidance_scale', 7.5),
+                steps=opt['val']['sample'].get('num_inference_steps', 50),
+                indice=indice)
+
+            save_img_path = osp.join(opt['path']['visualization'], dataset_name, f'{current_iter}', f'{img_name}---{current_iter}.png')
+
+            pil_imwrite(img, save_img_path)
+        # tentative for out of GPU memory
+        del output
+        torch.cuda.empty_cache()
+
+    # Save the lora layers, final eval
+    accelerator.wait_for_everyone()
+
+    if opt['val'].get('compose_visualize'):
+        if accelerator.is_main_process:
+            compose_visualize(os.path.dirname(save_img_path))
+
+
+def test(root_path, args):
+
+    # load config
+    opt = OmegaConf.to_container(OmegaConf.load(args.opt), resolve=True)
+
+    # set accelerator, mix-precision set in the environment by "accelerate config"
+    accelerator = Accelerator(mixed_precision=opt['mixed_precision'])
+
+    # set experiment dir
+    with accelerator.main_process_first():
+        set_path_logger(accelerator, root_path, args.opt, opt, is_train=False)
+
+    # get logger
+    logger = get_logger('mixofshow', log_level='INFO')
+    logger.info(accelerator.state, main_process_only=True)
+
+    logger.info(dict2str(opt))
+
+    # If passed along, set the training seed now.
+    if opt.get('manual_seed') is not None:
+        set_seed(opt['manual_seed'])
+
+    # Get the training dataset
+    valset_cfg = opt['datasets']['val_vis']
+    val_dataset = PromptDataset(valset_cfg)
+    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=valset_cfg['batch_size_per_gpu'], shuffle=False)
+
+    enable_edlora = opt['models']['enable_edlora']
+
+    for lora_alpha in opt['val']['alpha_list']:
+        pipeclass = EDLoRAPipeline if enable_edlora else StableDiffusionPipeline
+        pipe = pipeclass.from_pretrained(opt['models']['pretrained_path'],
+                                         scheduler=DPMSolverMultistepScheduler.from_pretrained(opt['models']['pretrained_path'], subfolder='scheduler'),
+                                         torch_dtype=torch.float16).to('cuda')
+        pipe, new_concept_cfg = convert_edlora(pipe, torch.load(opt['path']['lora_path']), enable_edlora=enable_edlora, alpha=lora_alpha)
+        pipe.set_new_concept_cfg(new_concept_cfg)
+        # visualize embedding + LoRA weight shift
+        logger.info(f'Start validation sample lora({lora_alpha}):')
+
+        lora_type = 'edlora' if enable_edlora else 'lora'
+        visual_validation(accelerator, pipe, val_dataloader, f'validation_{lora_type}_{lora_alpha}', opt)
+        del pipe
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-opt', type=str, default='options/test/EDLoRA/EDLoRA_hina_Anyv4_B4_Iter1K.yml')
+    args = parser.parse_args()
+
+    root_path = osp.abspath(osp.join(__file__, osp.pardir))
+    test(root_path, args)
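For context, a minimal invocation sketch; the YAML path is simply the script's own -opt default and stands in for any test config providing the fields the script reads (mixed_precision, manual_seed, datasets.val_vis, models.pretrained_path, models.enable_edlora, path.lora_path, val.alpha_list, val.sample). Since the script builds an Accelerator, it can also be launched through accelerate launch after running accelerate config:

    python test_edlora.py -opt options/test/EDLoRA/EDLoRA_hina_Anyv4_B4_Iter1K.yml
    # or, in an environment configured via accelerate (multi-GPU / mixed precision):
    accelerate launch test_edlora.py -opt options/test/EDLoRA/EDLoRA_hina_Anyv4_B4_Iter1K.yml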
train_edlora.py ADDED
@@ -0,0 +1,198 @@
+import argparse
+import copy
+import os
+import os.path as osp
+
+import torch
+import torch.utils.checkpoint
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import set_seed
+from diffusers import DPMSolverMultistepScheduler
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version
+from omegaconf import OmegaConf
+
+from mixofshow.data.lora_dataset import LoraDataset
+from mixofshow.data.prompt_dataset import PromptDataset
+from mixofshow.pipelines.pipeline_edlora import EDLoRAPipeline, StableDiffusionPipeline
+from mixofshow.pipelines.trainer_edlora import EDLoRATrainer
+from mixofshow.utils.convert_edlora_to_diffusers import convert_edlora
+from mixofshow.utils.util import MessageLogger, dict2str, reduce_loss_dict, set_path_logger
+from test_edlora import visual_validation
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version('0.18.2')
+
+
+def train(root_path, args):
+
+    # load config
+    opt = OmegaConf.to_container(OmegaConf.load(args.opt), resolve=True)
+
+    # set accelerator, mix-precision set in the environment by "accelerate config"
+    accelerator = Accelerator(mixed_precision=opt['mixed_precision'], gradient_accumulation_steps=opt['gradient_accumulation_steps'])
+
+    # set experiment dir
+    with accelerator.main_process_first():
+        set_path_logger(accelerator, root_path, args.opt, opt, is_train=True)
+
+    # get logger
+    logger = get_logger('mixofshow', log_level='INFO')
+    logger.info(accelerator.state, main_process_only=True)
+
+    logger.info(dict2str(opt))
+
+    # If passed along, set the training seed now.
+    if opt.get('manual_seed') is not None:
+        set_seed(opt['manual_seed'])
+
+    # Load model
+    EDLoRA_trainer = EDLoRATrainer(**opt['models'])
+
+    # set optimizer
+    train_opt = opt['train']
+    optim_type = train_opt['optim_g'].pop('type')
+    assert optim_type == 'AdamW', 'only support AdamW now'
+    optimizer = torch.optim.AdamW(EDLoRA_trainer.get_params_to_optimize(), **train_opt['optim_g'])
+
+    # Get the training dataset
+    trainset_cfg = opt['datasets']['train']
+    train_dataset = LoraDataset(trainset_cfg)
+    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=trainset_cfg['batch_size_per_gpu'], shuffle=True, drop_last=True)
+
+    # Get the training dataset
+    valset_cfg = opt['datasets']['val_vis']
+    val_dataset = PromptDataset(valset_cfg)
+    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=valset_cfg['batch_size_per_gpu'], shuffle=False)
+
+    # Prepare everything with our `accelerator`.
+    EDLoRA_trainer, optimizer, train_dataloader, val_dataloader = accelerator.prepare(EDLoRA_trainer, optimizer, train_dataloader, val_dataloader)
+
+    # Train!
+    total_batch_size = opt['datasets']['train']['batch_size_per_gpu'] * accelerator.num_processes * opt['gradient_accumulation_steps']
+    total_iter = len(train_dataset) / total_batch_size
+    opt['train']['total_iter'] = total_iter
+
+    logger.info('***** Running training *****')
+    logger.info(f' Num examples = {len(train_dataset)}')
+    logger.info(f" Instantaneous batch size per device = {opt['datasets']['train']['batch_size_per_gpu']}")
+    logger.info(f' Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}')
+    logger.info(f' Total optimization steps = {total_iter}')
+    global_step = 0
+
+    # Scheduler
+    lr_scheduler = get_scheduler(
+        'linear',
+        optimizer=optimizer,
+        num_warmup_steps=0,
+        num_training_steps=total_iter * opt['gradient_accumulation_steps'],
+    )
+
+    def make_data_yielder(dataloader):
+        while True:
+            for batch in dataloader:
+                yield batch
+            accelerator.wait_for_everyone()
+
+    train_data_yielder = make_data_yielder(train_dataloader)
+
+    msg_logger = MessageLogger(opt, global_step)
+    stop_emb_update = False
+
+    original_embedding = copy.deepcopy(accelerator.unwrap_model(EDLoRA_trainer).text_encoder.get_input_embeddings().weight)
+
+    while global_step < opt['train']['total_iter']:
+        with accelerator.accumulate(EDLoRA_trainer):
+
+            accelerator.unwrap_model(EDLoRA_trainer).unet.train()
+            accelerator.unwrap_model(EDLoRA_trainer).text_encoder.train()
+            loss_dict = {}
+
+            batch = next(train_data_yielder)
+
+            if 'masks' in batch:
+                masks = batch['masks']
+            else:
+                masks = batch['img_masks']
+
+            loss = EDLoRA_trainer(batch['images'], batch['prompts'], masks, batch['img_masks'])
+            loss_dict['loss'] = loss
+
+            # get fix embedding and learn embedding
+            index_no_updates = torch.arange(len(accelerator.unwrap_model(EDLoRA_trainer).tokenizer)) != -1
+            if not stop_emb_update:
+                for token_id in accelerator.unwrap_model(EDLoRA_trainer).get_all_concept_token_ids():
+                    index_no_updates[token_id] = False
+
+            accelerator.backward(loss)
+            optimizer.step()
+            lr_scheduler.step()
+            optimizer.zero_grad()
+
+            if accelerator.sync_gradients:
+                # set no update token to origin
+                token_embeds = accelerator.unwrap_model(EDLoRA_trainer).text_encoder.get_input_embeddings().weight
+                token_embeds.data[index_no_updates, :] = original_embedding.data[index_no_updates, :]
+
+                token_embeds = accelerator.unwrap_model(EDLoRA_trainer).text_encoder.get_input_embeddings().weight
+                concept_token_ids = accelerator.unwrap_model(EDLoRA_trainer).get_all_concept_token_ids()
+                loss_dict['Norm_mean'] = token_embeds[concept_token_ids].norm(dim=-1).mean()
+                if stop_emb_update is False and float(loss_dict['Norm_mean']) >= train_opt.get('emb_norm_threshold', 5.5e-1):
+                    stop_emb_update = True
+                    original_embedding = copy.deepcopy(accelerator.unwrap_model(EDLoRA_trainer).text_encoder.get_input_embeddings().weight)
+
+            log_dict = reduce_loss_dict(accelerator, loss_dict)
+
+        # Checks if the accelerator has performed an optimization step behind the scenes
+        if accelerator.sync_gradients:
+            global_step += 1
+
+            if global_step % opt['logger']['print_freq'] == 0:
+                log_vars = {'iter': global_step}
+                log_vars.update({'lrs': lr_scheduler.get_last_lr()})
+                log_vars.update(log_dict)
+                msg_logger(log_vars)
+
+            if global_step % opt['logger']['save_checkpoint_freq'] == 0:
+                save_and_validation(accelerator, opt, EDLoRA_trainer, val_dataloader, global_step, logger)
+
+    # Save the lora layers, final eval
+    accelerator.wait_for_everyone()
+    save_and_validation(accelerator, opt, EDLoRA_trainer, val_dataloader, 'latest', logger)
+
+
+def save_and_validation(accelerator, opt, EDLoRA_trainer, val_dataloader, global_step, logger):
+    enable_edlora = opt['models']['enable_edlora']
+    lora_type = 'edlora' if enable_edlora else 'lora'
+    save_path = os.path.join(opt['path']['models'], f'{lora_type}_model-{global_step}.pth')
+
+    if accelerator.is_main_process:
+        accelerator.save({'params': accelerator.unwrap_model(EDLoRA_trainer).delta_state_dict()}, save_path)
+        logger.info(f'Save state to {save_path}')
+
+    accelerator.wait_for_everyone()
+
+    if opt['val']['val_during_save']:
+        logger.info(f'Start validation {save_path}:')
+        for lora_alpha in opt['val']['alpha_list']:
+            pipeclass = EDLoRAPipeline if enable_edlora else StableDiffusionPipeline
+
+            pipe = pipeclass.from_pretrained(opt['models']['pretrained_path'],
+                                             scheduler=DPMSolverMultistepScheduler.from_pretrained(opt['models']['pretrained_path'], subfolder='scheduler'),
+                                             torch_dtype=torch.float16).to('cuda')
+            pipe, new_concept_cfg = convert_edlora(pipe, torch.load(save_path), enable_edlora=enable_edlora, alpha=lora_alpha)
+            pipe.set_new_concept_cfg(new_concept_cfg)
+            pipe.set_progress_bar_config(disable=True)
+            visual_validation(accelerator, pipe, val_dataloader, f'Iters-{global_step}_Alpha-{lora_alpha}', opt)
+
+            del pipe
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-opt', type=str, default='options/train/EDLoRA/EDLoRA_hina_Anyv4_B4_Iter1K.yml')
+    args = parser.parse_args()
+
+    root_path = osp.abspath(osp.join(__file__, osp.pardir))
+    train(root_path, args)
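For context, a minimal invocation sketch; the YAML path is the script's own -opt default and stands in for any training config providing the fields this script reads (mixed_precision, gradient_accumulation_steps, manual_seed, models, datasets.train, datasets.val_vis, train.optim_g, train.emb_norm_threshold, logger.print_freq, logger.save_checkpoint_freq, val.val_during_save, val.alpha_list). As the in-code comment notes, mixed precision and the number of processes come from the accelerate environment set up with accelerate config:

    accelerate launch train_edlora.py -opt options/train/EDLoRA/EDLoRA_hina_Anyv4_B4_Iter1K.yml

The checkpoints it writes ({lora_type}_model-{step}.pth with a 'params' entry) are the files that test_edlora.py loads via path.lora_path and that weight_fusion.py below reads as lora_path entries.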
weight_fusion.py ADDED
@@ -0,0 +1,699 @@
+import argparse
+import copy
+import itertools
+import json
+import logging
+import os
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.optim as optim
+from diffusers import DDPMScheduler, DPMSolverMultistepScheduler, StableDiffusionPipeline
+from tqdm import tqdm
+
+from mixofshow.models.edlora import revise_edlora_unet_attention_forward
+from mixofshow.pipelines.pipeline_edlora import bind_concept_prompt
+from mixofshow.utils.util import set_logger
+
+TEMPLATE_SIMPLE = 'photo of a {}'
+
+
+def chunk_compute_mse(K_target, V_target, W, device, chunk_size=5000):
+    num_chunks = (K_target.size(0) + chunk_size - 1) // chunk_size
+
+    loss = 0
+
+    for i in range(num_chunks):
+        # Extract the current chunk
+        start_idx = i * chunk_size
+        end_idx = min(start_idx + chunk_size, K_target.size(0))
+        loss += F.mse_loss(
+            F.linear(K_target[start_idx:end_idx].to(device), W),
+            V_target[start_idx:end_idx].to(device)) * (end_idx - start_idx)
+    loss /= K_target.size(0)
+    return loss
+
+
+def update_quasi_newton(K_target, V_target, W, iters, device):
+    '''
+    Args:
+        K: torch.Tensor, size [n_samples, n_features]
+        V: torch.Tensor, size [n_samples, n_targets]
+        K_target: torch.Tensor, size [n_constraints, n_features]
+        V_target: torch.Tensor, size [n_constraints, n_targets]
+        W: torch.Tensor, size [n_targets, n_features]
+
+    Returns:
+        Wnew: torch.Tensor, size [n_targets, n_features]
+    '''
+
+    W = W.detach()
+    V_target = V_target.detach()
+    K_target = K_target.detach()
+
+    W.requires_grad = True
+    K_target.requires_grad = False
+    V_target.requires_grad = False
+
+    best_loss = np.Inf
+    best_W = None
+
+    def closure():
+        nonlocal best_W, best_loss
+        optimizer.zero_grad()
+
+        if len(W.shape) == 4:
+            loss = F.mse_loss(F.conv2d(K_target.to(device), W),
+                              V_target.to(device))
+        else:
+            loss = chunk_compute_mse(K_target, V_target, W, device)
+
+        if loss < best_loss:
+            best_loss = loss
+            best_W = W.clone().cpu()
+        loss.backward()
+        return loss
+
+    optimizer = optim.LBFGS([W],
+                            lr=1,
+                            max_iter=iters,
+                            history_size=25,
+                            line_search_fn='strong_wolfe',
+                            tolerance_grad=1e-16,
+                            tolerance_change=1e-16)
+    optimizer.step(closure)
+
+    with torch.no_grad():
+        if len(W.shape) == 4:
+            loss = torch.norm(
+                F.conv2d(K_target.to(torch.float32), best_W.to(torch.float32)) - V_target.to(torch.float32), 2, dim=1)
+        else:
+            loss = torch.norm(
+                F.linear(K_target.to(torch.float32), best_W.to(torch.float32)) - V_target.to(torch.float32), 2, dim=1)
+
+    logging.info('new_concept loss: %e' % loss.mean().item())
+    return best_W
+
+
+def merge_lora_into_weight(original_state_dict, lora_state_dict, modification_layer_names, model_type, alpha, device):
+    def get_lora_down_name(original_layer_name):
+        if model_type == 'text_encoder':
+            lora_down_name = original_layer_name.replace('q_proj.weight', 'q_proj.lora_down.weight') \
+                .replace('k_proj.weight', 'k_proj.lora_down.weight') \
+                .replace('v_proj.weight', 'v_proj.lora_down.weight') \
+                .replace('out_proj.weight', 'out_proj.lora_down.weight') \
+                .replace('fc1.weight', 'fc1.lora_down.weight') \
+                .replace('fc2.weight', 'fc2.lora_down.weight')
+        else:
+            lora_down_name = k.replace('to_q.weight', 'to_q.lora_down.weight') \
+                .replace('to_k.weight', 'to_k.lora_down.weight') \
+                .replace('to_v.weight', 'to_v.lora_down.weight') \
+                .replace('to_out.0.weight', 'to_out.0.lora_down.weight') \
+                .replace('ff.net.0.proj.weight', 'ff.net.0.proj.lora_down.weight') \
+                .replace('ff.net.2.weight', 'ff.net.2.lora_down.weight') \
+                .replace('proj_out.weight', 'proj_out.lora_down.weight') \
+                .replace('proj_in.weight', 'proj_in.lora_down.weight')
+
+        return lora_down_name
+
+    assert model_type in ['unet', 'text_encoder']
+    new_state_dict = copy.deepcopy(original_state_dict)
+    load_cnt = 0
+
+    for k in modification_layer_names:
+        lora_down_name = get_lora_down_name(k)
+        lora_up_name = lora_down_name.replace('lora_down', 'lora_up')
+
+        if lora_up_name in lora_state_dict:
+            load_cnt += 1
+            original_params = new_state_dict[k]
+            lora_down_params = lora_state_dict[lora_down_name].to(device)
+            lora_up_params = lora_state_dict[lora_up_name].to(device)
+            if len(original_params.shape) == 4:
+                lora_param = lora_up_params.squeeze(
+                ) @ lora_down_params.squeeze()
+                lora_param = lora_param.unsqueeze(-1).unsqueeze(-1)
+            else:
+                lora_param = lora_up_params @ lora_down_params
+            merge_params = original_params + alpha * lora_param
+            new_state_dict[k] = merge_params
+
+    logging.info(f'load {load_cnt} LoRAs of {model_type}')
+    return new_state_dict
+
+
+module_io_recoder = {}
+record_feature = False  # remember to set record feature
+
+
+def get_hooker(module_name):
+    def hook(module, feature_in, feature_out):
+        if module_name not in module_io_recoder:
+            module_io_recoder[module_name] = {'input': [], 'output': []}
+        if record_feature:
+            module_io_recoder[module_name]['input'].append(feature_in[0].cpu())
+            if module.bias is not None:
+                if len(feature_out.shape) == 4:
+                    bias = module.bias.unsqueeze(-1).unsqueeze(-1)
+                else:
+                    bias = module.bias
+                module_io_recoder[module_name]['output'].append(
+                    (feature_out - bias).cpu())  # remove bias
+            else:
+                module_io_recoder[module_name]['output'].append(
+                    feature_out.cpu())
+
+    return hook
+
+
+def init_stable_diffusion(pretrained_model_path, device):
+    # step1: get w0 parameters
+    model_id = pretrained_model_path
+    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)
+
+    train_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder='scheduler')
+    test_scheduler = DPMSolverMultistepScheduler.from_pretrained(model_id, subfolder='scheduler')
+    pipe.safety_checker = None
+    pipe.scheduler = test_scheduler
+    return pipe, train_scheduler, test_scheduler
+
+
+@torch.no_grad()
+def get_text_feature(prompts, tokenizer, text_encoder, device, return_type='category_embedding'):
+    text_features = []
+
+    if return_type == 'category_embedding':
+        for text in prompts:
+            tokens = tokenizer(
+                text,
+                truncation=True,
+                max_length=tokenizer.model_max_length,
+                return_length=True,
+                return_overflowing_tokens=False,
+                padding='do_not_pad',
+            ).input_ids
+
+            new_token_position = torch.where(torch.tensor(tokens) >= 49407)[0]
+            # >40497 not include end token | >=40497 include end token
+            concept_feature = text_encoder(
+                torch.LongTensor(tokens).reshape(
+                    1, -1).to(device))[0][:,
+                                          new_token_position].reshape(-1, 768)
+            text_features.append(concept_feature)
+        return torch.cat(text_features, 0).float()
+    elif return_type == 'full_embedding':
+        text_input = tokenizer(prompts,
+                               padding='max_length',
+                               max_length=tokenizer.model_max_length,
+                               truncation=True,
+                               return_tensors='pt')
+        text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
+        return text_embeddings
+    else:
+        raise NotImplementedError
+
+
+def merge_new_concepts_(embedding_list, concept_list, tokenizer, text_encoder):
+    def add_new_concept(concept_name, embedding):
+        new_token_names = [
+            f'<new{start_idx + layer_id}>'
+            for layer_id in range(NUM_CROSS_ATTENTION_LAYERS)
+        ]
+        num_added_tokens = tokenizer.add_tokens(new_token_names)
+        assert num_added_tokens == NUM_CROSS_ATTENTION_LAYERS
+        new_token_ids = [
+            tokenizer.convert_tokens_to_ids(token_name)
+            for token_name in new_token_names
+        ]
+
+        text_encoder.resize_token_embeddings(len(tokenizer))
+        token_embeds = text_encoder.get_input_embeddings().weight.data
+
+        token_embeds[new_token_ids] = token_embeds[new_token_ids].copy_(
+            embedding[concept_name])
+
+        embedding_features.update({concept_name: embedding[concept_name]})
+        logging.info(
+            f'concept {concept_name} is bind with token_id: [{min(new_token_ids)}, {max(new_token_ids)}]'
+        )
+
+        return start_idx + NUM_CROSS_ATTENTION_LAYERS, new_token_ids, new_token_names
+
+    embedding_features = {}
+    new_concept_cfg = {}
+
+    start_idx = 0
+
+    NUM_CROSS_ATTENTION_LAYERS = 16
+
+    for idx, (embedding,
+              concept) in enumerate(zip(embedding_list, concept_list)):
+        concept_names = concept['concept_name'].split(' ')
+
+        for concept_name in concept_names:
+            if not concept_name.startswith('<'):
+                continue
+            else:
+                assert concept_name in embedding, 'check the config, the provide concept name is not in the lora model'
+            start_idx, new_token_ids, new_token_names = add_new_concept(
+                concept_name, embedding)
+            new_concept_cfg.update({
+                concept_name: {
+                    'concept_token_ids': new_token_ids,
+                    'concept_token_names': new_token_names
+                }
+            })
+    return embedding_features, new_concept_cfg
+
+
+def parse_new_concepts(concept_cfg):
+    with open(concept_cfg, 'r') as f:
+        concept_list = json.load(f)
+
+    model_paths = [concept['lora_path'] for concept in concept_list]
+
+    embedding_list = []
+    text_encoder_list = []
+    unet_crosskv_list = []
+    unet_spatial_attn_list = []
+
+    for model_path in model_paths:
+        model = torch.load(model_path)['params']
+
+        if 'new_concept_embedding' in model and len(
+                model['new_concept_embedding']) != 0:
+            embedding_list.append(model['new_concept_embedding'])
+        else:
+            embedding_list.append(None)
+
+        if 'text_encoder' in model and len(model['text_encoder']) != 0:
+            text_encoder_list.append(model['text_encoder'])
+        else:
+            text_encoder_list.append(None)
+
+        if 'unet' in model and len(model['unet']) != 0:
+            crosskv_matches = ['attn2.to_k.lora', 'attn2.to_v.lora']
+            crosskv_dict = {
+                k: v
+                for k, v in model['unet'].items()
+                if any([x in k for x in crosskv_matches])
+            }
+
+            if len(crosskv_dict) != 0:
+                unet_crosskv_list.append(crosskv_dict)
+            else:
+                unet_crosskv_list.append(None)
+
+            spatial_attn_dict = {
+                k: v
+                for k, v in model['unet'].items()
+                if all([x not in k for x in crosskv_matches])
+            }
+
+            if len(spatial_attn_dict) != 0:
+                unet_spatial_attn_list.append(spatial_attn_dict)
+            else:
+                unet_spatial_attn_list.append(None)
+        else:
+            unet_crosskv_list.append(None)
+            unet_spatial_attn_list.append(None)
+
+    return embedding_list, text_encoder_list, unet_crosskv_list, unet_spatial_attn_list, concept_list
+
+
+def merge_kv_in_cross_attention(concept_list, optimize_iters, new_concept_cfg,
+                                tokenizer, text_encoder, unet,
+                                unet_crosskv_list, device):
+    # crosskv attention layer names
+    matches = ['attn2.to_k', 'attn2.to_v']
+
+    cross_attention_idx = -1
+    cross_kv_layer_names = []
+
+    # the crosskv name should match the order down->mid->up, and record its layer id
+    for name, _ in unet.down_blocks.named_parameters():
+        if any([x in name for x in matches]):
+            if 'to_k' in name:
+                cross_attention_idx += 1
+                cross_kv_layer_names.append(
+                    (cross_attention_idx, 'down_blocks.' + name))
+                cross_kv_layer_names.append(
+                    (cross_attention_idx,
+                     'down_blocks.' + name.replace('to_k', 'to_v')))
+            else:
+                pass
+
+    for name, _ in unet.mid_block.named_parameters():
+        if any([x in name for x in matches]):
+            if 'to_k' in name:
+                cross_attention_idx += 1
+                cross_kv_layer_names.append(
+                    (cross_attention_idx, 'mid_block.' + name))
+                cross_kv_layer_names.append(
+                    (cross_attention_idx,
+                     'mid_block.' + name.replace('to_k', 'to_v')))
+            else:
+                pass
+
+    for name, _ in unet.up_blocks.named_parameters():
+        if any([x in name for x in matches]):
+            if 'to_k' in name:
+                cross_attention_idx += 1
+                cross_kv_layer_names.append(
+                    (cross_attention_idx, 'up_blocks.' + name))
+                cross_kv_layer_names.append(
+                    (cross_attention_idx,
+                     'up_blocks.' + name.replace('to_k', 'to_v')))
+            else:
+                pass
+
+    logging.info(
+        f'Unet have {len(cross_kv_layer_names)} linear layer (related to text feature) need to optimize'
+    )
+
+    original_unet_state_dict = unet.state_dict()  # original state dict
+    concept_weights_dict = {}
+
+    # step 1: construct prompts for new concept -> extract input/target features
+    for concept, tuned_state_dict in zip(concept_list, unet_crosskv_list):
+
+        for layer_idx, layer_name in cross_kv_layer_names:
+
+            # merge params
+            original_params = original_unet_state_dict[layer_name]
+
+            # hard coded here: in unet, self/crosskv attention disable bias parameter
+            lora_down_name = layer_name.replace('to_k.weight', 'to_k.lora_down.weight').replace('to_v.weight', 'to_v.lora_down.weight')
+            lora_up_name = lora_down_name.replace('lora_down', 'lora_up')
+
+            alpha = concept['unet_alpha']
+
+            lora_down_params = tuned_state_dict[lora_down_name].to(device)
+            lora_up_params = tuned_state_dict[lora_up_name].to(device)
+
+            merge_params = original_params + alpha * lora_up_params @ lora_down_params
+
+            if layer_name not in concept_weights_dict:
+                concept_weights_dict[layer_name] = []
+
+            concept_weights_dict[layer_name].append(merge_params)
+
+
+    new_kv_weights = {}
+    # step 3: begin update model
+    for idx, (layer_idx, layer_name) in enumerate(cross_kv_layer_names):
+        Wnew = torch.stack(concept_weights_dict[layer_name])
+        Wnew = torch.mean(Wnew, dim = 0)
+        new_kv_weights[layer_name] = Wnew
+
+    return new_kv_weights
+
+
+def merge_text_encoder(concept_list, optimize_iters, new_concept_cfg,
+                       tokenizer, text_encoder, text_encoder_list, device):
+
+    LoRA_keys = []
+    for textenc_lora in text_encoder_list:
+        LoRA_keys += list(textenc_lora.keys())
+    LoRA_keys = set([
+        key.replace('.lora_down', '').replace('.lora_up', '')
+        for key in LoRA_keys
+    ])
+    text_encoder_layer_names = LoRA_keys
+
+    candidate_module_name = [
+        'q_proj', 'k_proj', 'v_proj', 'out_proj', 'fc1', 'fc2'
+    ]
+    candidate_module_name = [
+        name for name in candidate_module_name
+        if any([name in key for key in LoRA_keys])
+    ]
+
+    logging.info(f'text_encoder have {len(text_encoder_layer_names)} linear layer need to optimize')
+
+    global module_io_recoder, record_feature
+    hooker_handlers = []
+    for name, module in text_encoder.named_modules():
+        if any([item in name for item in candidate_module_name]):
+            hooker_handlers.append(module.register_forward_hook(hook=get_hooker(name)))
+
+    logging.info(f'add {len(hooker_handlers)} hooker to text_encoder')
+
+    original_state_dict = copy.deepcopy(text_encoder.state_dict())  # original state dict
+
+    new_concept_input_dict = {}
+    new_concept_output_dict = {}
+    concept_weights_dict = {}
+
+    for concept, lora_state_dict in zip(concept_list, text_encoder_list):
+        merged_state_dict = merge_lora_into_weight(
+            original_state_dict,
+            lora_state_dict,
+            text_encoder_layer_names,
+            model_type='text_encoder',
+            alpha=concept['text_encoder_alpha'],
+            device=device)
+        text_encoder.load_state_dict(merged_state_dict)  # load merged parameters
+        # we use different model to compute new concept feature
+        for layer_name in text_encoder_layer_names:
+            if layer_name not in concept_weights_dict:
+                concept_weights_dict[layer_name] = []
+            concept_weights_dict[layer_name].append(merged_state_dict[layer_name])
+
+    new_text_encoder_weights = {}
+    # step 3: begin update model
+    for idx, layer_name in enumerate(text_encoder_layer_names):
+        Wnew = torch.stack(concept_weights_dict[layer_name])
+        Wnew = torch.mean(Wnew, dim = 0)
+        new_text_encoder_weights[layer_name] = Wnew
+
+    logging.info(f'remove {len(hooker_handlers)} hooker from text_encoder')
+
+    # remove forward hooker
+    for hook_handle in hooker_handlers:
+        hook_handle.remove()
+
+    return new_text_encoder_weights
+
+
+@torch.no_grad()
+def decode_to_latents(concept_prompt, new_concept_cfg, tokenizer, text_encoder,
+                      unet, test_scheduler, num_inference_steps, device,
+                      record_nums, batch_size):
+
+    concept_prompt = bind_concept_prompt([concept_prompt], new_concept_cfg)
+    text_embeddings = get_text_feature(
+        concept_prompt,
+        tokenizer,
+        text_encoder,
+        device,
+        return_type='full_embedding').unsqueeze(0)
+
+    text_embeddings = text_embeddings.repeat((batch_size, 1, 1, 1))
+
+    # sd 1.x
+    height = 512
+    width = 512
+
+    latents = torch.randn((batch_size, unet.in_channels, height // 8, width // 8), )
+    latents = latents.to(device, dtype=text_embeddings.dtype)
+
+    test_scheduler.set_timesteps(num_inference_steps)
+    latents = latents * test_scheduler.init_noise_sigma
+
+    global record_feature
+    step = (test_scheduler.timesteps.size(0)) // record_nums
+    record_timestep = test_scheduler.timesteps[torch.arange(0, test_scheduler.timesteps.size(0), step=step)[:record_nums]]
+
+    for t in tqdm(test_scheduler.timesteps):
+
+        if t in record_timestep:
+            record_feature = True
+        else:
+            record_feature = False
+
+        # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
+        latent_model_input = latents
+        latent_model_input = test_scheduler.scale_model_input(latent_model_input, t)
+
+        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+        # compute the previous noisy sample x_t -> x_t-1
+        latents = test_scheduler.step(noise_pred, t, latents).prev_sample
+
+    return latents, text_embeddings
+
+
+def merge_spatial_attention(concept_list, optimize_iters, new_concept_cfg, tokenizer, text_encoder, unet, unet_spatial_attn_list, test_scheduler, device):
+    LoRA_keys = []
+    for unet_lora in unet_spatial_attn_list:
+        LoRA_keys += list(unet_lora.keys())
+    LoRA_keys = set([
+        key.replace('.lora_down', '').replace('.lora_up', '')
+        for key in LoRA_keys
+    ])
+    spatial_attention_layer_names = LoRA_keys
+
+    candidate_module_name = [
+        'attn2.to_q', 'attn2.to_out.0', 'attn1.to_q', 'attn1.to_k',
+        'attn1.to_v', 'attn1.to_out.0', 'ff.net.2', 'ff.net.0.proj',
+        'proj_out', 'proj_in'
+    ]
+    candidate_module_name = [
+        name for name in candidate_module_name
+        if any([name in key for key in LoRA_keys])
+    ]
+
+    logging.info(
+        f'unet have {len(spatial_attention_layer_names)} linear layer need to optimize'
+    )
+
+    global module_io_recoder
+    hooker_handlers = []
+    for name, module in unet.named_modules():
+        if any([x in name for x in candidate_module_name]):
+            hooker_handlers.append(
+                module.register_forward_hook(hook=get_hooker(name)))
+
+    logging.info(f'add {len(hooker_handlers)} hooker to unet')
+
+    original_state_dict = copy.deepcopy(unet.state_dict())  # original state dict
+    revise_edlora_unet_attention_forward(unet)
+
+    concept_weights_dict = {}
+
+    for concept, tuned_state_dict in zip(concept_list, unet_spatial_attn_list):
+        # set unet
+        module_io_recoder = {}  # reinit module io recorder
+
+        merged_state_dict = merge_lora_into_weight(
+            original_state_dict,
+            tuned_state_dict,
+            spatial_attention_layer_names,
+            model_type='unet',
+            alpha=concept['unet_alpha'],
+            device=device)
+        unet.load_state_dict(merged_state_dict)  # load merged parameters
+
+        concept_name = concept['concept_name']
+        concept_prompt = TEMPLATE_SIMPLE.format(concept_name)
+
+
+        for layer_name in spatial_attention_layer_names:
+            if layer_name not in concept_weights_dict:
+                concept_weights_dict[layer_name] = []
+
+            concept_weights_dict[layer_name].append(merged_state_dict[layer_name])
+
+    new_spatial_attention_weights = {}
+    # step 5: begin update model
+    for idx, layer_name in enumerate(spatial_attention_layer_names):
+        Wnew = torch.stack(concept_weights_dict[layer_name])
+        Wnew = torch.mean(Wnew, dim = 0)
+        new_spatial_attention_weights[layer_name] = Wnew
+
+    logging.info(f'remove {len(hooker_handlers)} hooker from unet')
+
+    for hook_handle in hooker_handlers:
+        hook_handle.remove()
+
+    return new_spatial_attention_weights
+
+
+def compose_concepts(concept_cfg, optimize_textenc_iters, optimize_unet_iters, pretrained_model_path, save_path, suffix, device):
+    logging.info('------Step 1: load stable diffusion checkpoint------')
+    pipe, train_scheduler, test_scheduler = init_stable_diffusion(pretrained_model_path, device)
+    tokenizer, text_encoder, unet, vae = pipe.tokenizer, pipe.text_encoder, pipe.unet, pipe.vae
+    for param in itertools.chain(text_encoder.parameters(), unet.parameters(), vae.parameters()):
+        param.requires_grad = False
+
+    logging.info('------Step 2: load new concepts checkpoints------')
+    embedding_list, text_encoder_list, unet_crosskv_list, unet_spatial_attn_list, concept_list = parse_new_concepts(concept_cfg)
+
+    # step 1: inplace add new concept to tokenizer and embedding layers of text encoder
+    if any([item is not None for item in embedding_list]):
+        logging.info('------Step 3: merge token embedding------')
+        _, new_concept_cfg = merge_new_concepts_(embedding_list, concept_list, tokenizer, text_encoder)
+    else:
+        _, new_concept_cfg = {}, {}
+        logging.info('------Step 3: no new embedding, skip merging token embedding------')
+
+    # step 2: construct reparameterized text_encoder
+    if any([item is not None for item in text_encoder_list]):
+        logging.info('------Step 4: merge text encoder------')
+        new_text_encoder_weights = merge_text_encoder(
+            concept_list, optimize_textenc_iters, new_concept_cfg, tokenizer,
+            text_encoder, text_encoder_list, device)
+
+        # update the merged state_dict in text_encoder
+        text_encoder_state_dict = text_encoder.state_dict()
+        text_encoder_state_dict.update(new_text_encoder_weights)
+        text_encoder.load_state_dict(text_encoder_state_dict)
+    else:
+        new_text_encoder_weights = {}
+        logging.info('------Step 4: no new text encoder, skip merging text encoder------')
+
+
+    # step 3: merge unet (k,v in crosskv-attention) params, since they only receive input from text-encoder
+
+    if any([item is not None for item in unet_crosskv_list]):
+        logging.info('------Step 5: merge kv of cross-attention in unet------')
+        new_kv_weights = merge_kv_in_cross_attention(
+            concept_list, optimize_textenc_iters, new_concept_cfg,
+            tokenizer, text_encoder, unet, unet_crosskv_list, device)
+        # update the merged state_dict in kv of crosskv-attention in Unet
+        unet_state_dict = unet.state_dict()
+        unet_state_dict.update(new_kv_weights)
+        unet.load_state_dict(unet_state_dict)
+    else:
+        new_kv_weights = {}
+        logging.info('------Step 5: no new kv of cross-attention in unet, skip merging kv------')
+
+    # step 4: merge unet (q,k,v in self-attention, q in crosskv-attention)
+    if any([item is not None for item in unet_spatial_attn_list]):
+        logging.info('------Step 6: merge spatial attention (q in cross-attention, qkv in self-attention) in unet------')
+        new_spatial_attention_weights = merge_spatial_attention(
+            concept_list, optimize_unet_iters, new_concept_cfg, tokenizer,
+            text_encoder, unet, unet_spatial_attn_list, test_scheduler, device)
+        unet_state_dict = unet.state_dict()
+        unet_state_dict.update(new_spatial_attention_weights)
+        unet.load_state_dict(unet_state_dict)
+    else:
+        new_spatial_attention_weights = {}
+        logging.info('------Step 6: no new spatial-attention in unet, skip merging spatial attention------')
+
+    checkpoint_save_path = f'{save_path}/combined_model_{suffix}'
+    pipe.save_pretrained(checkpoint_save_path)
+    with open(os.path.join(checkpoint_save_path, 'new_concept_cfg.json'), 'w') as json_file:
+        json.dump(new_concept_cfg, json_file)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser('', add_help=False)
+    parser.add_argument('--concept_cfg', help='json file for multi-concept', required=True, type=str)
+    parser.add_argument('--save_path', help='folder name to save optimized weights', required=True, type=str)
+    parser.add_argument('--suffix', help='suffix name', default='base', type=str)
+    parser.add_argument('--pretrained_models', required=True, type=str)
+    parser.add_argument('--optimize_unet_iters', default=50, type=int)
+    parser.add_argument('--optimize_textenc_iters', default=500, type=int)
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    # s1: set logger
+    exp_dir = f'{args.save_path}'
+    os.makedirs(exp_dir, exist_ok=True)
+    log_file = f'{exp_dir}/combined_model_{args.suffix}.log'
+    set_logger(log_file=log_file)
+    logging.info(args)
+
+    compose_concepts(args.concept_cfg,
+                     args.optimize_textenc_iters,
+                     args.optimize_unet_iters,
+                     args.pretrained_models,
+                     args.save_path,
+                     args.suffix,
+                     device='cuda')
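For context, a minimal invocation sketch plus the shape of the JSON that parse_new_concepts() expects; all paths and alpha values below are illustrative placeholders. Each entry in the --concept_cfg list must provide at least lora_path (an EDLoRA checkpoint saved by train_edlora.py, loaded via its 'params' entry), concept_name (the space-separated <token> names of the concept), and the text_encoder_alpha / unet_alpha merge weights read by the merge functions above:

    python weight_fusion.py \
        --concept_cfg datasets/data_cfgs/multi_concept.json \
        --save_path experiments/composed_edlora/anythingv4 \
        --pretrained_models path/to/base_stable_diffusion \
        --suffix base

    # illustrative multi_concept.json
    [
        {
            "lora_path": "experiments/concept1/models/edlora_model-latest.pth",
            "concept_name": "<concept1_1> <concept1_2>",
            "text_encoder_alpha": 1.0,
            "unet_alpha": 1.0
        },
        {
            "lora_path": "experiments/concept2/models/edlora_model-latest.pth",
            "concept_name": "<concept2_1> <concept2_2>",
            "text_encoder_alpha": 1.0,
            "unet_alpha": 1.0
        }
    ]

The fused pipeline and its new_concept_cfg.json are written to {save_path}/combined_model_{suffix}, which is exactly the layout that build_model() in regionally_controlable_sampling.py expects for its --pretrained_model argument.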