# DiffModels / app.py  (Hugging Face Space)
# Commit 3e76729 (SemaSci): "Fixed a bug — added ip_adapter_image to the
# parameter list of the infer function."
import gradio as gr
import numpy as np
import random
import os
import torch
from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image
from peft import PeftModel, LoraConfig
from rembg import remove
# Run on GPU when available; diffusers pipelines work on CPU too, just slowly.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Default base checkpoint (Hub repo id) used when the UI field is left as-is.
model_id_default = "stable-diffusion-v1-5/stable-diffusion-v1-5"
if torch.cuda.is_available():
    # Half precision halves VRAM usage on GPU; CPU kernels need float32.
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32
# Upper bound for user-supplied seeds (fits a 32-bit signed int).
MAX_SEED = np.iinfo(np.int32).max
# Upper bound for the width/height sliders, in pixels.
MAX_IMAGE_SIZE = 1024
# @spaces.GPU #[uncomment to use ZeroGPU]
def infer(
    prompt,
    negative_prompt,
    width=512,
    height=512,
    model_id=model_id_default,
    seed=42,
    guidance_scale=7.0,
    lora_scale=1.0,
    num_inference_steps=20,
    controlnet_checkbox=False,
    controlnet_strength=0.0,
    controlnet_mode="edge_detection",
    controlnet_image=None,
    ip_adapter_checkbox=False,
    ip_adapter_scale=0.0,
    ip_adapter_image=None,
    remove_bg=None,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate an image with Stable Diffusion plus a local LoRA fine-tune.

    Optionally conditions generation with ControlNet and/or IP-Adapter, and
    optionally removes the background of the result with rembg.

    Args:
        prompt: Text prompt describing the desired image.
        negative_prompt: Text describing what to avoid.
        width, height: Output resolution in pixels.
        model_id: Base checkpoint repo id or local path.
        seed: RNG seed; cast to int because the UI may deliver a float.
        guidance_scale: Classifier-free guidance strength.
        lora_scale: LoRA strength, applied via ``cross_attention_kwargs``.
        num_inference_steps: Number of denoising steps.
        controlnet_checkbox: Enable ControlNet conditioning.
        controlnet_strength: ControlNet conditioning scale.
        controlnet_mode: One of edge_detection / depth_map / pose_estimation /
            normal_map / scribbles (anything else falls back to edge detection).
        controlnet_image: PIL condition image for ControlNet.
        ip_adapter_checkbox: Enable IP-Adapter conditioning.
        ip_adapter_scale: IP-Adapter strength.
        ip_adapter_image: PIL reference image for IP-Adapter.
        remove_bg: If truthy, strip the background from the final image.
        progress: Gradio progress tracker (forwards diffusers' tqdm bars).

    Returns:
        A PIL image with the generated result.

    Raises:
        ValueError: If ``model_id`` is missing, or a conditioning feature is
            enabled without its image.
    """
    ckpt_dir = './lora_pussinboots_logos'  # local LoRA fine-tune of the UNet
    unet_sub_dir = os.path.join(ckpt_dir, "unet")
    # text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")

    if model_id is None:
        raise ValueError("Please specify the base model name or path")

    # Gradio Number widgets can deliver floats; manual_seed requires an int.
    generator = torch.Generator(device).manual_seed(int(seed))
    params = {
        'prompt': prompt,
        'negative_prompt': negative_prompt,
        'guidance_scale': guidance_scale,
        'num_inference_steps': num_inference_steps,
        'width': width,
        'height': height,
        'generator': generator,
    }

    if controlnet_checkbox:
        # Fail early with a clear message instead of a cryptic pipeline error.
        if controlnet_image is None:
            raise ValueError("ControlNet is enabled but no condition image was provided")
        # Map UI mode names to ControlNet checkpoints; unknown modes fall back
        # to Canny edge detection (the original default branch).
        checkpoints = {
            "depth_map": "lllyasviel/sd-controlnet-depth",
            "pose_estimation": "lllyasviel/sd-controlnet-openpose",
            "normal_map": "lllyasviel/sd-controlnet-normal",
            "scribbles": "lllyasviel/sd-controlnet-scribble",
        }
        controlnet = ControlNetModel.from_pretrained(
            checkpoints.get(controlnet_mode, "lllyasviel/sd-controlnet-canny"),
            cache_dir="./models_cache",
            torch_dtype=torch_dtype,
        )
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            model_id,
            controlnet=controlnet,
            torch_dtype=torch_dtype,
            safety_checker=None,
        ).to(device)
        params['image'] = controlnet_image
        params['controlnet_conditioning_scale'] = float(controlnet_strength)
    else:
        pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            safety_checker=None,
        ).to(device)

    # Attach the LoRA weights to the UNet via PEFT.
    pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir)
    # pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir)

    # LoRA strength is applied through cross_attention_kwargs rather than by
    # rescaling the state dict (which would corrupt the weights).
    params['cross_attention_kwargs'] = {"scale": float(lora_scale)}
    # pipe.text_encoder.load_state_dict({k: lora_scale*v for k, v in pipe.text_encoder.state_dict().items()})

    if torch_dtype in (torch.float16, torch.bfloat16):
        # PEFT loads its adapter in float32; match the pipeline's dtype.
        pipe.unet.half()
        # pipe.text_encoder.half()

    if ip_adapter_checkbox:
        if ip_adapter_image is None:
            raise ValueError("IP-Adapter is enabled but no reference image was provided")
        pipe.load_ip_adapter(
            "h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin"
        )
        pipe.set_ip_adapter_scale(ip_adapter_scale)
        params['ip_adapter_image'] = ip_adapter_image

    # Re-sync devices in case newly loaded adapter weights landed on the CPU.
    pipe.to(device)

    image = pipe(**params).images[0]

    # Optionally strip the background with rembg.
    if remove_bg:
        image = remove(image)
    return image
# Example prompts shown below the prompt box in the UI.
# NOTE: the third entry previously used ""About the Cat"" — in Python that is
# adjacent-string concatenation, not quote escaping, so the quotation marks
# were silently dropped from the prompt. Fixed with escaped quotes.
examples = [
    "Puss in Boots wearing a sombrero crosses the Grand Canyon on a tightrope with a guitar.",
    "Cat wearing a sombrero crosses the Grand Canyon on a tightrope with a guitar.",
    "A cat is playing a song called \"About the Cat\" on an accordion by the sea at sunset. The sun is quickly setting behind the horizon, and the light is fading.",
    "A cat walks through the grass on the streets of an abandoned city. The camera view is always focused on the cat's face.",
    "A young lady in a Russian embroidered kaftan is sitting on a beautiful carved veranda, holding a cup to her mouth and drinking tea from the cup. With her other hand, the girl holds a saucer. The cup and saucer are painted with gzhel. Next to the girl on the table stands a samovar, and steam can be seen above it.",
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""
def controlnet_params(show_extra):
    """Return a Gradio update that toggles a component's visibility."""
    # NOTE(review): appears unused — the UI wires visibility with inline
    # lambdas, and this name is shadowed by the `as controlnet_params`
    # Column alias created inside the Blocks context below.
    return gr.update(visible=show_extra)
# UI definition. Fixes relative to the original:
#  - `gr.Row.update(visible=x)` -> `gr.update(visible=x)`: Row.update was
#    removed in Gradio 4, and the output component is a Column anyway.
#  - `precision=0` on the seed Number so `infer` receives an int
#    (torch.Generator.manual_seed rejects floats).
with gr.Blocks(css=css, fill_height=True) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # Text-to-Image demo")

        with gr.Row():
            model_id = gr.Textbox(
                label="Model ID",
                max_lines=1,
                placeholder="Enter model id",
                value=model_id_default,
            )

        prompt = gr.Textbox(
            label="Prompt",
            max_lines=1,
            placeholder="Enter your prompt",
        )
        negative_prompt = gr.Textbox(
            label="Negative prompt",
            max_lines=1,
            placeholder="Enter your negative prompt",
        )

        with gr.Row():
            seed = gr.Number(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
                precision=0,  # deliver an int — manual_seed rejects floats
            )
            guidance_scale = gr.Slider(
                label="Guidance scale",
                minimum=0.0,
                maximum=30.0,
                step=0.1,
                value=7.0,  # Replace with defaults that work for your model
            )

        with gr.Row():
            lora_scale = gr.Slider(
                label="LoRA scale",
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=1.0,
            )
            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=1,
                maximum=100,
                step=1,
                value=20,  # Replace with defaults that work for your model
            )

        with gr.Row():
            controlnet_checkbox = gr.Checkbox(
                label="ControlNet",
                value=False,
            )
            # Extra ControlNet settings, shown only while the box is checked.
            with gr.Column(visible=False) as controlnet_params:
                controlnet_strength = gr.Slider(
                    label="ControlNet conditioning scale",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=1.0,
                )
                controlnet_mode = gr.Dropdown(
                    label="ControlNet mode",
                    choices=[
                        "edge_detection",
                        "depth_map",
                        "pose_estimation",
                        "normal_map",
                        "scribbles",
                    ],
                    value="edge_detection",
                    max_choices=1,
                )
                controlnet_image = gr.Image(
                    label="ControlNet condition image",
                    type="pil",
                    format="png",
                )
            controlnet_checkbox.change(
                fn=lambda x: gr.update(visible=x),
                inputs=controlnet_checkbox,
                outputs=controlnet_params,
            )

        with gr.Row():
            ip_adapter_checkbox = gr.Checkbox(
                label="IPAdapter",
                value=False,
            )
            # Extra IP-Adapter settings, shown only while the box is checked.
            with gr.Column(visible=False) as ip_adapter_params:
                ip_adapter_scale = gr.Slider(
                    label="IPAdapter scale",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=1.0,
                )
                ip_adapter_image = gr.Image(
                    label="IPAdapter condition image",
                    type="pil",
                )
            ip_adapter_checkbox.change(
                fn=lambda x: gr.update(visible=x),
                inputs=ip_adapter_checkbox,
                outputs=ip_adapter_params,
            )

        with gr.Accordion("Optional Settings", open=False):
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,  # Replace with defaults that work for your model
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,  # Replace with defaults that work for your model
                )

        # Background removal ------------------------------------------------
        # Checkbox that makes infer() strip the background with rembg.
        remove_bg = gr.Checkbox(
            label="Remove Background",
            value=False,
            interactive=True,
        )
        # -------------------------------------------------------------------

        gr.Examples(examples=examples, inputs=[prompt])
        run_button = gr.Button("Run", scale=0, variant="primary")
        result = gr.Image(label="Result", show_label=False)

    gr.on(
        triggers=[run_button.click],
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            width,
            height,
            model_id,
            seed,
            guidance_scale,
            lora_scale,
            num_inference_steps,
            controlnet_checkbox,
            controlnet_strength,
            controlnet_mode,
            controlnet_image,
            ip_adapter_checkbox,
            ip_adapter_scale,
            ip_adapter_image,
            remove_bg,
        ],
        outputs=[result],
    )
if __name__ == "__main__":
    # Start the Gradio server (blocks until the app is stopped).
    demo.launch()