# peft-lora-sd-dreambooth / inference.py
from __future__ import annotations

import gc
import json
import pathlib

import gradio as gr
import PIL.Image
import torch
from diffusers import StableDiffusionPipeline
from peft import LoraConfig, LoraModel, set_peft_model_state_dict


class InferencePipeline:
    def __init__(self):
        self.pipe = None
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.weight_path = None

    def clear(self) -> None:
        # Drop the loaded pipeline and release GPU memory.
        self.weight_path = None
        del self.pipe
        self.pipe = None
        torch.cuda.empty_cache()
        gc.collect()

    @staticmethod
    def get_lora_weight_path(name: str) -> tuple[pathlib.Path, pathlib.Path]:
        # Returns the weight file path and its companion "*_config.json" path,
        # e.g. "foo.pt" maps to ("foo.pt", "foo_config.json").
        curr_dir = pathlib.Path(__file__).parent
        return curr_dir / name, curr_dir / f'{name.replace(".pt", "_config.json")}'

    def load_and_set_lora_ckpt(self, pipe, weight_path, config_path, dtype):
        with open(config_path, "r") as f:
            lora_config = json.load(f)
        lora_checkpoint_sd = torch.load(weight_path, map_location=self.device)

        # The checkpoint stores UNet and text-encoder LoRA weights in one flat
        # state dict; text-encoder entries are prefixed with "text_encoder_".
        unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k}
        text_encoder_lora_ds = {
            k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k
        }

        # Wrap the UNet with LoRA layers and load the trained weights into them.
        unet_config = LoraConfig(**lora_config["peft_config"])
        pipe.unet = LoraModel(unet_config, pipe.unet)
        set_peft_model_state_dict(pipe.unet, unet_lora_ds)

        # The text encoder is only wrapped if it was fine-tuned as well.
        if "text_encoder_peft_config" in lora_config:
            text_encoder_config = LoraConfig(**lora_config["text_encoder_peft_config"])
            pipe.text_encoder = LoraModel(text_encoder_config, pipe.text_encoder)
            set_peft_model_state_dict(pipe.text_encoder, text_encoder_lora_ds)

        if dtype in (torch.float16, torch.bfloat16):
            pipe.unet.half()
            pipe.text_encoder.half()
        pipe.to(self.device)
        return pipe
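
    # A sketch of the config JSON this method expects. The exact fields depend
    # on the PEFT version that produced the checkpoint, so the values below are
    # illustrative assumptions rather than the canonical schema:
    #
    #   {
    #       "peft_config": {"peft_type": "LORA", "r": 8, "lora_alpha": 27,
    #                       "target_modules": ["to_q", "to_v"], "lora_dropout": 0.0},
    #       "text_encoder_peft_config": {"peft_type": "LORA", "r": 8, ...}
    #   }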

    def load_pipe(self, model_id: str, lora_filename: str) -> None:
        weight_path, config_path = self.get_lora_weight_path(lora_filename)
        if weight_path == self.weight_path:
            return  # These LoRA weights are already loaded.
        pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(self.device)
        pipe = self.load_and_set_lora_ckpt(pipe, weight_path, config_path, torch.float16)
        self.pipe = pipe
        # Record the loaded weights so the early-return check above can fire.
        self.weight_path = weight_path

    def run(
        self,
        base_model: str,
        lora_weight_name: str,
        prompt: str,
        negative_prompt: str,
        seed: int,
        n_steps: int,
        guidance_scale: float,
    ) -> PIL.Image.Image:
        if not torch.cuda.is_available():
            raise gr.Error("CUDA is not available.")
        self.load_pipe(base_model, lora_weight_name)
        # Seed the generator on the target device for reproducible sampling.
        generator = torch.Generator(device=self.device).manual_seed(seed)
        out = self.pipe(
            prompt,
            num_inference_steps=n_steps,
            guidance_scale=guidance_scale,
            generator=generator,
            negative_prompt=negative_prompt if negative_prompt else None,
        )  # type: ignore
        return out.images[0]
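

# A minimal usage sketch (not part of the original Space, which drives this
# class from a Gradio UI): the model ID is a public Stable Diffusion
# checkpoint, and "my_dreambooth_lora.pt" is a hypothetical LoRA file assumed
# to sit next to this script with a matching "my_dreambooth_lora_config.json".
if __name__ == "__main__":
    pipeline = InferencePipeline()
    image = pipeline.run(
        base_model="runwayml/stable-diffusion-v1-5",
        lora_weight_name="my_dreambooth_lora.pt",
        prompt="a photo of sks dog in a bucket",
        negative_prompt="",
        seed=0,
        n_steps=50,
        guidance_scale=7.5,
    )
    image.save("out.png")
    pipeline.clear()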