B-LoRa-Inference / inference.py
fffiloni's picture
Upload 3 files
a40a1a9 verified
import argparse
import torch
from diffusers import StableDiffusionXLPipeline, AutoencoderKL
from blora_utils import BLOCKS, filter_lora, scale_lora
def parse_args():
    """Parse the command-line options of the B-LoRA inference script.

    Returns:
        argparse.Namespace with the attributes ``prompt``, ``output_path``,
        ``content_B_LoRA``, ``style_B_LoRA``, ``content_alpha``,
        ``style_alpha`` and ``num_images_per_prompt``.
    """
    # One row per option: (flag, value type, extra add_argument kwargs, help).
    # Rows are registered in order, so this is also the order shown by --help.
    option_specs = (
        ("--prompt", str, {"required": True}, "B-LoRA prompt"),
        ("--output_path", str, {"required": True}, "path to save the images"),
        ("--content_B_LoRA", str, {"default": None}, "path for the content B-LoRA"),
        ("--style_B_LoRA", str, {"default": None}, "path for the style B-LoRA"),
        (
            "--content_alpha",
            float,
            {"default": 1.},
            "alpha parameter to scale the content B-LoRA weights",
        ),
        (
            "--style_alpha",
            float,
            {"default": 1.},
            "alpha parameter to scale the style B-LoRA weights",
        ),
        (
            "--num_images_per_prompt",
            int,
            {"default": 4},
            "number of images per prompt",
        ),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, extra_kwargs, help_text in option_specs:
        cli.add_argument(flag, type=value_type, help=help_text, **extra_kwargs)
    return cli.parse_args()
def _load_b_lora(pipeline, lora_path, block_names, alpha):
    """Load one B-LoRA checkpoint, keep the requested blocks and scale them.

    Args:
        pipeline: the SDXL pipeline whose ``lora_state_dict`` reads the file.
        lora_path: value of ``--content_B_LoRA`` / ``--style_B_LoRA``;
            ``None`` means this B-LoRA was not requested.
        block_names: entry of ``BLOCKS`` passed on to ``filter_lora``.
        alpha: scale factor passed on to ``scale_lora``.

    Returns:
        The state dict produced by ``scale_lora(filter_lora(...))``, or an
        empty dict when ``lora_path`` is ``None``.
    """
    if lora_path is None:
        return {}
    # lora_state_dict returns a pair; only the state dict is used here.
    state_dict, _ = pipeline.lora_state_dict(lora_path)
    # NOTE(review): filter_lora presumably keeps only the weights that belong
    # to `block_names`, and scale_lora multiplies them by `alpha` -- confirm
    # against blora_utils.
    return scale_lora(filter_lora(state_dict, block_names), alpha)


def main():
    """Generate images from a prompt with optional content/style B-LoRAs.

    Reads the options from the command line, builds the SDXL pipeline,
    loads the merged B-LoRA weights into its UNet, runs the prompt and
    writes ``<output_path>/<prompt>_<i>.jpg`` for every generated image.
    """
    import os  # local import: only this entry point needs it

    args = parse_args()

    # Create the destination before the expensive model load and generation,
    # so a bad output path fails immediately instead of after sampling.
    os.makedirs(args.output_path, exist_ok=True)

    vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
    )
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        vae=vae,
        torch_dtype=torch.float16,
    ).to("cuda")

    # Get Content / Style B-LoRA state dicts (empty when not requested)
    content_B_LoRA = _load_b_lora(
        pipeline, args.content_B_LoRA, BLOCKS['content'], args.content_alpha
    )
    style_B_LoRA = _load_b_lora(
        pipeline, args.style_B_LoRA, BLOCKS['style'], args.style_alpha
    )

    # Merge B-LoRAs SD; on a key collision the style entry wins.
    res_lora = {**content_B_LoRA, **style_B_LoRA}

    # Load
    pipeline.load_lora_into_unet(res_lora, None, pipeline.unet)

    # Generate
    images = pipeline(
        args.prompt, num_images_per_prompt=args.num_images_per_prompt
    ).images

    # Save. The prompt is part of the file name, so path separators inside it
    # are replaced; otherwise the path would point into a missing directory.
    safe_prompt = args.prompt.replace(os.sep, "_")
    if os.altsep:
        safe_prompt = safe_prompt.replace(os.altsep, "_")
    for i, img in enumerate(images):
        img.save(os.path.join(args.output_path, f'{safe_prompt}_{i}.jpg'))


if __name__ == '__main__':
    main()