# sdpipe_webui/utils/functions.py
# Commit 919fef8 (1lint): refactor code and fix cpu support.
import gradio as gr
import torch
import random
from PIL import Image
import os
import argparse
import shutil
import gc
import importlib
import json
from diffusers import (
StableDiffusionPipeline,
StableDiffusionImg2ImgPipeline,
)
from .inpaint_pipeline import SDInpaintPipeline as StableDiffusionInpaintPipelineLegacy
from .textual_inversion import main as run_textual_inversion
from .shared import default_scheduler, scheduler_dict, model_ids
# True when the optional xformers package is importable (memory-efficient attention).
_xformers_available = importlib.util.find_spec("xformers") is not None
# Prefer the GPU; fall back to CPU when CUDA is unavailable.
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = 'cpu'
# fp16 halves VRAM on GPU; CPU kernels need full fp32.
dtype = torch.float16 if device == "cuda" else torch.float32
# When True, load_pipe enables attention slicing to reduce peak VRAM.
low_vram_mode = False
# UI tab index -> diffusers pipeline class used by that tab.
tab_to_pipeline = {
    1: StableDiffusionPipeline,
    2: StableDiffusionImg2ImgPipeline,
    3: StableDiffusionInpaintPipelineLegacy,
}
def load_pipe(model_id, scheduler_name, tab_index=1, pipe_kwargs="{}"):
    """Return a diffusers pipeline for `model_id` configured for the given UI tab.

    Caches the loaded model in the module globals `pipe`/`loaded_model_id` so
    that switching tabs or schedulers reuses the already-loaded torch modules
    instead of re-reading weights from disk.

    Args:
        model_id: Hugging Face model id (or local path) to load.
        scheduler_name: key into `scheduler_dict` selecting the scheduler class.
        tab_index: key into `tab_to_pipeline` selecting the pipeline class.
        pipe_kwargs: JSON string of extra kwargs forwarded to `from_pretrained`.

    Returns:
        The ready-to-use pipeline (also stored in the module global `pipe`).
    """
    global pipe, loaded_model_id

    scheduler = scheduler_dict[scheduler_name]
    pipe_class = tab_to_pipeline[tab_index]

    if model_id != loaded_model_id:
        # Different model: load new weights from disk.
        pipe = pipe_class.from_pretrained(
            model_id,
            torch_dtype=dtype,
            safety_checker=None,
            requires_safety_checker=False,
            scheduler=scheduler.from_pretrained(model_id, subfolder="scheduler"),
            **json.loads(pipe_kwargs),
        )
        loaded_model_id = model_id
    elif pipe_class != pipe.__class__ or not isinstance(pipe.scheduler, scheduler):
        # Same weights, different pipeline class or scheduler: rebuild around the
        # already-loaded modules. NOTE: `pipe.components` is a property that
        # returns a *new* dict on every access, so the previous in-place
        # `pipe.components["scheduler"] = ...` mutated a throwaway dict and the
        # scheduler swap was lost. Capture the dict once, then construct.
        components = pipe.components
        components["scheduler"] = scheduler.from_pretrained(
            model_id, subfolder="scheduler"
        )
        pipe = pipe_class(**components)

    if device == "cuda":
        pipe = pipe.to(device)
        if _xformers_available:
            pipe.enable_xformers_memory_efficient_attention()
            print("using xformers")
    if low_vram_mode:
        pipe.enable_attention_slicing()
        print("using attention slicing to lower VRAM")

    return pipe
# Module-level pipeline cache mutated by load_pipe(): `pipe` holds the current
# pipeline object and `loaded_model_id` the model it was built from.
pipe = None
loaded_model_id = ""
# Eagerly load the default model/scheduler at import time so the UI is ready.
pipe = load_pipe(model_ids[0], default_scheduler)
def pad_image(image):
    """Pad `image` with black to a centered square whose side is max(w, h).

    Returns the image unchanged when it is already square. Uses `Image.new`'s
    default all-zero (black) fill so the padding works for any image mode —
    the previous explicit `(0, 0, 0)` tuple was RGB-specific and fails for
    single-band modes such as "L".

    Args:
        image: a PIL Image.

    Returns:
        A new square PIL Image (or the original if already square).
    """
    w, h = image.size
    if w == h:
        return image
    side = max(w, h)
    padded = Image.new(image.mode, (side, side))
    # Center the original along its shorter axis; the offset for the longer
    # axis is zero, so one paste covers both the wide and tall cases.
    padded.paste(image, ((side - w) // 2, (side - h) // 2))
    return padded
@torch.no_grad()
def generate(
    model_name,
    scheduler_name,
    prompt,
    guidance,
    steps,
    n_images=1,
    width=512,
    height=512,
    seed=0,
    image=None,
    strength=0.5,
    inpaint_image=None,
    inpaint_strength=0.5,
    inpaint_radio="",
    neg_prompt="",
    tab_index=1,
    pipe_kwargs="{}",
    progress=gr.Progress(track_tqdm=True),
):
    """Run the diffusion pipeline for the active UI tab and return its images.

    tab_index selects the mode: 1 = text-to-image, 2 = image-to-image,
    3 = inpainting. Returns (list_of_images, status_message), or
    (None, error_message) for an unknown tab.
    """
    # A seed of -1 means "pick one at random"; keep the chosen value so the
    # status line reports a reproducible seed.
    if seed == -1:
        seed = random.randint(0, 2147483647)
    generator = torch.Generator(device).manual_seed(seed)

    pipe = load_pipe(
        model_id=model_name,
        scheduler_name=scheduler_name,
        tab_index=tab_index,
        pipe_kwargs=pipe_kwargs,
    )

    status_message = (
        f"Prompt: '{prompt}' | Seed: {seed} | Guidance: {guidance}"
        f" | Scheduler: {scheduler_name} | Steps: {steps}"
    )

    # Keyword arguments common to every pipeline mode.
    shared_kwargs = dict(
        negative_prompt=neg_prompt,
        num_images_per_prompt=n_images,
        num_inference_steps=int(steps),
        guidance_scale=guidance,
        generator=generator,
    )

    if tab_index == 1:
        status_message = "Text to Image " + status_message
        result = pipe(prompt, width=width, height=height, **shared_kwargs)
    elif tab_index == 2:
        status_message = "Image to Image " + status_message
        print(image.size)
        image = image.resize((width, height))
        print(image.size)
        result = pipe(prompt, image=image, strength=strength, **shared_kwargs)
    elif tab_index == 3:
        status_message = "Inpainting " + status_message
        init_image = inpaint_image["image"].resize((width, height))
        mask = inpaint_image["mask"].resize((width, height))
        result = pipe(
            prompt,
            image=init_image,
            mask_image=mask,
            strength=inpaint_strength,
            preserve_unmasked_image=(
                inpaint_radio == "preserve non-masked portions of image"
            ),
            **shared_kwargs,
        )
    else:
        return None, f"Unhandled tab index: {tab_index}"

    return result.images, status_message
# based on lvkaokao/textual-inversion-training
def train_textual_inversion(
    model_name,
    scheduler_name,
    type_of_thing,
    files,
    concept_word,
    init_word,
    text_train_steps,
    text_train_bsz,
    text_learning_rate,
    progress=gr.Progress(track_tqdm=True),
):
    """Train a textual-inversion embedding for `concept_word` on the uploaded images.

    Preprocesses the uploads into `concept_images/`, runs the training script,
    and writes the learned weights to `output_model/`. CUDA-only.

    Args:
        model_name: model id to fine-tune.
        scheduler_name: scheduler key for the pipeline.
        type_of_thing: learnable property ("object", "style", ...).
        files: gradio file uploads (objects with a `.name` path attribute).
        concept_word: placeholder token to learn (required).
        init_word: token whose embedding initializes the new token.
        text_train_steps / text_train_bsz / text_learning_rate: training hyperparams.

    Returns:
        A human-readable completion message.

    Raises:
        gr.Error: on CPU, on a missing concept word, or if training fails.
    """
    if device == "cpu":
        raise gr.Error("Textual inversion training not supported on CPU")

    # Validate user input *before* the destructive filesystem cleanup below —
    # previously an empty concept word was detected only after the old output
    # directories had already been deleted.
    if not concept_word:
        raise gr.Error("You forgot to define your concept prompt")

    pipe = load_pipe(
        model_id=model_name,
        scheduler_name=scheduler_name,
        tab_index=1,
    )
    pipe.disable_xformers_memory_efficient_attention()  # xformers handled by textual inversion script

    concept_dir = "concept_images"
    output_dir = "output_model"
    training_resolution = 512

    # Start from clean scratch directories.
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    if os.path.exists(concept_dir):
        shutil.rmtree(concept_dir)
    os.makedirs(concept_dir, exist_ok=True)
    os.makedirs(output_dir, exist_ok=True)

    gc.collect()
    torch.cuda.empty_cache()

    # Preprocess each upload: pad to square, resize, force RGB, save numbered.
    for j, file_temp in enumerate(files):
        image = pad_image(Image.open(file_temp.name))
        image = image.resize((training_resolution, training_resolution))
        # splitext is robust to filenames containing extra dots, unlike the
        # previous split(".")[1]; strip the leading "." from the suffix.
        extension = os.path.splitext(file_temp.name)[1].lstrip(".")
        image = image.convert("RGB")
        image.save(f"{concept_dir}/{j+1}.{extension}", quality=100)

    args_general = argparse.Namespace(
        train_data_dir=concept_dir,
        learnable_property=type_of_thing,
        placeholder_token=concept_word,
        initializer_token=init_word,
        resolution=training_resolution,
        train_batch_size=text_train_bsz,
        gradient_accumulation_steps=1,
        gradient_checkpointing=True,
        mixed_precision="fp16",
        use_bf16=False,
        max_train_steps=int(text_train_steps),
        learning_rate=text_learning_rate,
        scale_lr=True,
        lr_scheduler="constant",
        lr_warmup_steps=0,
        output_dir=output_dir,
    )

    try:
        run_textual_inversion(pipe, args_general)
    except Exception as e:
        # gr.Error expects a message string; keep the original cause chained.
        raise gr.Error(str(e)) from e

    # Restore inference state and dtype on the target device after training.
    pipe.text_encoder = pipe.text_encoder.eval().to(device, dtype=dtype)
    pipe.unet = pipe.unet.eval().to(device, dtype=dtype)

    gc.collect()
    torch.cuda.empty_cache()

    return (
        f"Finished training! Check the {output_dir} directory for saved model weights"
    )