fffiloni committed on
Commit
a40a1a9
1 Parent(s): 70c32aa

Upload 3 files

Browse files
Files changed (3) hide show
  1. blora_utils.py +46 -0
  2. inference.py +69 -0
  3. requirements.txt +11 -0
blora_utils.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
# Substrings identifying which LoRA state-dict keys belong to each B-LoRA role.
# The 'content' and 'style' lists are passed to filter_lora() by the inference
# script, and get_target_modules() falls back to both lists combined.
BLOCKS = {
    'content': [
        'unet.up_blocks.0.attentions.0',
    ],
    'style': [
        'unet.up_blocks.0.attentions.1',
    ],
}
7
+
8
+
9
def is_belong_to_blocks(key, blocks):
    """Return True if any entry of ``blocks`` is contained in ``key``.

    Args:
        key: Object searched with the ``in`` operator (callers in this file
            pass state-dict keys and attention-processor names).
        blocks: Iterable of block identifiers to look for inside ``key``.

    Returns:
        True when at least one element of ``blocks`` is found in ``key``,
        False otherwise (including when ``blocks`` is empty).

    Raises:
        Exception: Any failure is re-raised as the same exception type with a
            'failed to is_belong_to_block' prefix, chained to the original
            error. If that type cannot be rebuilt from a single message, the
            original exception is re-raised untouched.
    """
    try:
        return any(block in key for block in blocks)
    except Exception as e:
        message = f'failed to is_belong_to_block, due to: {e}'
        try:
            wrapped = type(e)(message)
        except Exception:
            # Some exception classes (e.g. UnicodeDecodeError) need more than
            # one constructor argument; do not mask the real error with a
            # TypeError coming from the wrapper itself.
            wrapped = None
        if wrapped is None:
            raise
        raise wrapped from e
17
+
18
+
19
def filter_lora(state_dict, blocks_):
    """Keep only the state-dict entries whose key matches one of ``blocks_``.

    Args:
        state_dict: Mapping of parameter names to values.
        blocks_: Iterable of block identifiers; an entry is kept when
            ``is_belong_to_blocks(key, blocks_)`` is true for its key.

    Returns:
        A new dict with the matching ``(key, value)`` pairs; the input mapping
        is not modified.

    Raises:
        Exception: Any failure is re-raised as the same exception type with a
            'failed to filter_lora' prefix, chained to the original error. If
            that type cannot be rebuilt from a single message, the original
            exception is re-raised untouched.
    """
    try:
        return {k: v for k, v in state_dict.items() if is_belong_to_blocks(k, blocks_)}
    except Exception as e:
        message = f'failed to filter_lora, due to: {e}'
        try:
            wrapped = type(e)(message)
        except Exception:
            # The exception class does not accept a lone message argument;
            # surface the original error instead of a constructor TypeError.
            wrapped = None
        if wrapped is None:
            raise
        raise wrapped from e
24
+
25
+
26
def scale_lora(state_dict, alpha):
    """Multiply every value of ``state_dict`` by ``alpha``.

    Args:
        state_dict: Mapping whose values support ``value * alpha``.
        alpha: Scaling factor applied to each value.

    Returns:
        A new dict with the same keys and scaled values; the input mapping is
        not modified.

    Raises:
        Exception: Any failure is re-raised as the same exception type with a
            'failed to scale_lora' prefix, chained to the original error. If
            that type cannot be rebuilt from a single message, the original
            exception is re-raised untouched.
    """
    try:
        return {k: v * alpha for k, v in state_dict.items()}
    except Exception as e:
        message = f'failed to scale_lora, due to: {e}'
        try:
            wrapped = type(e)(message)
        except Exception:
            # The exception class does not accept a lone message argument;
            # surface the original error instead of a constructor TypeError.
            wrapped = None
        if wrapped is None:
            raise
        raise wrapped from e
31
+
32
+
33
def get_target_modules(unet, blocks=None):
    """List the attention projection module names inside the selected blocks.

    Args:
        unet: Object exposing an ``attn_processors`` mapping keyed by
            attention-processor name.
        blocks: Optional iterable of block identifiers. When empty or None,
            the content and style entries of ``BLOCKS`` are used with their
            first dotted component removed.

    Returns:
        A list of ``'<attention>.<projection>'`` names for the projections
        ``to_k``, ``to_q``, ``to_v`` and ``to_out.0``, grouped by projection
        first and by attention module second.

    Raises:
        Exception: Any failure is re-raised as the same exception type with a
            'failed to get_target_modules' prefix, chained to the original
            error. If that type cannot be rebuilt from a single message, the
            original exception is re-raised untouched.
    """
    try:
        if not blocks:
            # Drop the first dotted component ('unet' in BLOCKS) — presumably
            # because attn_processors keys are relative to the UNet itself.
            blocks = ['.'.join(blk.split('.')[1:]) for blk in BLOCKS['content'] + BLOCKS['style']]

        # Strip the trailing '.<processor>' component to get the attention module path.
        attns = [
            attn_processor_name.rsplit('.', 1)[0]
            for attn_processor_name in unet.attn_processors
            if is_belong_to_blocks(attn_processor_name, blocks)
        ]

        target_modules = [f'{attn}.{mat}' for mat in ["to_k", "to_q", "to_v", "to_out.0"] for attn in attns]
        return target_modules
    except Exception as e:
        message = f'failed to get_target_modules, due to: {e}'
        try:
            wrapped = type(e)(message)
        except Exception:
            # The exception class does not accept a lone message argument;
            # surface the original error instead of a constructor TypeError.
            wrapped = None
        if wrapped is None:
            raise
        raise wrapped from e
45
+
46
+
inference.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import torch
4
+ from diffusers import StableDiffusionXLPipeline, AutoencoderKL
5
+
6
+ from blora_utils import BLOCKS, filter_lora, scale_lora
7
+
8
+
9
def parse_args(argv=None):
    """Parse the command-line options of the B-LoRA inference script.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads the arguments from ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with ``prompt``, ``output_path``,
        ``content_B_LoRA``, ``style_B_LoRA``, ``content_alpha``,
        ``style_alpha`` and ``num_images_per_prompt``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--prompt", type=str, required=True, help="B-LoRA prompt"
    )
    parser.add_argument(
        "--output_path", type=str, required=True, help="path to save the images"
    )
    parser.add_argument(
        "--content_B_LoRA", type=str, default=None, help="path for the content B-LoRA"
    )
    parser.add_argument(
        "--style_B_LoRA", type=str, default=None, help="path for the style B-LoRA"
    )
    parser.add_argument(
        "--content_alpha", type=float, default=1., help="alpha parameter to scale the content B-LoRA weights"
    )
    parser.add_argument(
        "--style_alpha", type=float, default=1., help="alpha parameter to scale the style B-LoRA weights"
    )
    parser.add_argument(
        "--num_images_per_prompt", type=int, default=4, help="number of images per prompt"
    )
    return parser.parse_args(argv)
33
+
34
+
35
if __name__ == '__main__':
    # Script entry point: load SDXL, merge the requested content/style
    # B-LoRA weights into the UNet, generate images and save them.
    import os  # local import: only the script path needs filesystem helpers

    args = parse_args()
    vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
    pipeline = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0",
                                                         vae=vae,
                                                         torch_dtype=torch.float16).to("cuda")

    # Get Content B-LoRA SD: keep only the content block weights, scaled by content_alpha
    if args.content_B_LoRA is not None:
        content_B_LoRA_sd, _ = pipeline.lora_state_dict(args.content_B_LoRA)
        content_B_LoRA = filter_lora(content_B_LoRA_sd, BLOCKS['content'])
        content_B_LoRA = scale_lora(content_B_LoRA, args.content_alpha)
    else:
        content_B_LoRA = {}

    # Get Style B-LoRA SD: keep only the style block weights, scaled by style_alpha
    if args.style_B_LoRA is not None:
        style_B_LoRA_sd, _ = pipeline.lora_state_dict(args.style_B_LoRA)
        style_B_LoRA = filter_lora(style_B_LoRA_sd, BLOCKS['style'])
        style_B_LoRA = scale_lora(style_B_LoRA, args.style_alpha)
    else:
        style_B_LoRA = {}

    # Merge B-LoRAs SD (style entries win on a key collision)
    res_lora = {**content_B_LoRA, **style_B_LoRA}

    # Load — skipped when no B-LoRA weights were selected, so there is nothing
    # to inject and the base pipeline is used as-is.
    if res_lora:
        pipeline.load_lora_into_unet(res_lora, None, pipeline.unet)

    # Generate
    images = pipeline(args.prompt, num_images_per_prompt=args.num_images_per_prompt).images

    # Save
    # Create the target directory up front so saving does not fail on a missing path.
    os.makedirs(args.output_path, exist_ok=True)
    # The prompt is part of the file name; path separators in it would be
    # interpreted as sub-directories, so replace them.
    safe_prompt = args.prompt.replace('/', '_').replace(os.sep, '_')
    for i, img in enumerate(images):
        img.save(os.path.join(args.output_path, f'{safe_prompt}_{i}.jpg'))
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate
2
+ bitsandbytes==0.36.0.post2
3
+ datasets
4
+ diffusers==0.25.0
5
+ ftfy==6.1.1
6
+ huggingface-hub
7
+ Pillow==9.4.0
8
+ python-slugify==7.0.0
9
+ torch
10
+ torchvision
11
+ transformers