# diffusion / app.py — Hugging Face Space by um235
# Last commit: "Update app.py" (ca7d365, verified), 11.7 kB
import gradio as gr
import numpy as np
import random
from peft import PeftModel, LoraConfig
from diffusers import DiffusionPipeline, StableDiffusionControlNetPipeline
from diffusers import ControlNetModel
import torch
from PIL import Image
from rembg import remove
# Run on the GPU when torch can see one. Half precision is used only there;
# CPU inference stays in float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# Upper bound for the seed slider and for randomly drawn seeds (largest int32).
MAX_SEED = np.iinfo(np.int32).max
# Upper bound, in pixels, for the width/height sliders.
MAX_IMAGE_SIZE = 1024
# ControlNet modes list with aliases
CONTROLNET_MODES = {
"Canny Edge Detection": "lllyasviel/control_v11p_sd15_canny",
"Pixel to Pixel": "lllyasviel/control_v11e_sd15_ip2p",
"Inpainting": "lllyasviel/control_v11p_sd15_inpaint",
"Multi-Level Line Segments": "lllyasviel/control_v11p_sd15_mlsd",
"Depth Estimation": "lllyasviel/control_v11f1p_sd15_depth",
"Surface Normal Estimation": "lllyasviel/control_v11p_sd15_normalbae",
"Image Segmentation": "lllyasviel/control_v11p_sd15_seg",
"Line Art Generation": "lllyasviel/control_v11p_sd15_lineart",
"Anime Line Art": "lllyasviel/control_v11p_sd15_lineart_anime",
"Human Pose Estimation": "lllyasviel/control_v11p_sd15_openpose",
"Scribble-Based Generation": "lllyasviel/control_v11p_sd15_scribble",
"Soft Edge Generation": "lllyasviel/control_v11p_sd15_softedge",
"Image Shuffling": "lllyasviel/control_v11e_sd15_shuffle",
"Image Tiling": "lllyasviel/control_v11f1e_sd15_tile",
}
# Hub repo of the base weights behind the "SD1.5 + lora ..." dropdown aliases.
_SD15_BASE = "stable-diffusion-v1-5/stable-diffusion-v1-5"


def _load_pipeline(model_id, controlnet_model=None):
    """Load the pipeline selected in the "Choose Model" dropdown.

    With ``controlnet_model`` a ``StableDiffusionControlNetPipeline`` is built;
    otherwise a plain ``DiffusionPipeline``. Both use the module-wide
    ``torch_dtype``.

    "SD1.5 + lora Unet TextEncoder" loads the SD1.5 base weights and wraps the
    UNet and text encoder with the PEFT adapters from "um235/vCat_v2".
    "SD1.5 + lora Unet" wraps only the UNet, with "um235/cartoon_cat_stickers".
    Any other ``model_id`` is passed to ``from_pretrained`` unchanged.
    """
    if controlnet_model is None:
        pipeline_cls, extra_kwargs = DiffusionPipeline, {}
    else:
        pipeline_cls = StableDiffusionControlNetPipeline
        extra_kwargs = {"controlnet": controlnet_model}

    is_lora_alias = model_id in ("SD1.5 + lora Unet TextEncoder", "SD1.5 + lora Unet")
    repo_id = _SD15_BASE if is_lora_alias else model_id
    pipe = pipeline_cls.from_pretrained(repo_id, torch_dtype=torch_dtype, **extra_kwargs)

    if model_id == "SD1.5 + lora Unet TextEncoder":
        pipe.unet = PeftModel.from_pretrained(pipe.unet, "um235/vCat_v2", subfolder="unet")
        pipe.text_encoder = PeftModel.from_pretrained(
            pipe.text_encoder, "um235/vCat_v2", subfolder="text_encoder"
        )
    elif model_id == "SD1.5 + lora Unet":
        pipe.unet = PeftModel.from_pretrained(pipe.unet, "um235/cartoon_cat_stickers")
    return pipe


# @spaces.GPU #[uncomment to use ZeroGPU]
def infer(
    model_id,
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    lscale=0.0,
    controlnet_enabled=False,
    control_strength=0.0,
    control_mode=None,
    control_image=None,
    ip_adapter_enabled=False,
    ip_adapter_scale=0.0,
    ip_adapter_image=None,
    d_bckg=False,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image and return it together with the seed that was used.

    Gradio calls this from the "Run" button or a prompt submit. It passes the
    UI values in the order of ``gr.on(inputs=[...])`` below.

    Args:
        model_id: Hub repo id, or one of the "SD1.5 + lora ..." aliases
            (see ``_load_pipeline``).
        prompt: Positive text prompt.
        negative_prompt: Negative text prompt.
        seed: Seed for the ``torch.Generator``. It is replaced when
            ``randomize_seed`` is true.
        randomize_seed: Draw a fresh seed from ``[0, MAX_SEED]``.
        width: Output width, forwarded to the pipeline call.
        height: Output height, forwarded to the pipeline call.
        guidance_scale: Forwarded to the pipeline call.
        num_inference_steps: Forwarded to the pipeline call.
        lscale: Sent as ``cross_attention_kwargs={"scale": lscale}``.
        controlnet_enabled: Use a ControlNet pipeline. This only takes effect
            when ``control_image`` is provided.
        control_strength: ``controlnet_conditioning_scale`` for the ControlNet.
        control_mode: Key of ``CONTROLNET_MODES`` that selects the checkpoint.
        control_image: Conditioning image for the ControlNet.
        ip_adapter_enabled: Load the IP-Adapter. This only takes effect when
            ``ip_adapter_image`` is provided.
        ip_adapter_scale: Value passed to ``set_ip_adapter_scale``.
        ip_adapter_image: Reference image. It is converted to RGB and resized
            to 510x510 before use.
        d_bckg: Run ``rembg.remove`` on the result ("Delete Background").
        progress: Gradio progress tracker (``track_tqdm=True``).

    Returns:
        ``(image, seed)`` — the generated image and the seed actually used.
    """
    control_strength = float(control_strength)
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)
    generator = torch.Generator().manual_seed(seed)

    # The visibility callbacks only hide the image widgets; they do not clear
    # them. So each feature switches on only when both its checkbox is set and
    # its image is present. A missing IP-Adapter image used to crash on
    # ``None.convert``.
    use_controlnet = bool(controlnet_enabled) and control_image is not None
    use_ip_adapter = bool(ip_adapter_enabled) and ip_adapter_image is not None

    controlnet_model = None
    if use_controlnet:
        # Load the ControlNet in the same dtype as the pipeline so its weights
        # match the UNet's.
        controlnet_model = ControlNetModel.from_pretrained(
            CONTROLNET_MODES.get(control_mode), torch_dtype=torch_dtype
        )
    pipe = _load_pipeline(model_id, controlnet_model)

    if use_ip_adapter:
        pipe.load_ip_adapter(
            "h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin"
        )
        pipe.set_ip_adapter_scale(ip_adapter_scale)
    pipe.safety_checker = None
    pipe = pipe.to(device)

    # Pass feature-specific inputs only for the features that are actually
    # enabled. An IP-Adapter image given to a pipeline with no adapter loaded
    # would fail.
    extra_call_kwargs = {}
    if use_controlnet:
        extra_call_kwargs["image"] = control_image
        extra_call_kwargs["controlnet_conditioning_scale"] = control_strength
    if use_ip_adapter:
        extra_call_kwargs["ip_adapter_image"] = (
            ip_adapter_image.convert("RGB").resize((510, 510))
        )

    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        width=int(width),
        height=int(height),
        generator=generator,
        cross_attention_kwargs={"scale": lscale},
        **extra_call_kwargs,
    ).images[0]

    if d_bckg:
        image = remove(image)
    return image, seed


# Example prompts offered under the prompt box.
examples = [
    "Sticker VanillaCat. Cartoon-style cat with soft yellow fur and a white flower on its head, sitting up with a relaxed expression, eyes half-closed, content and calm, casual pose, peaceful mood, white background.",
    "Sticker VanillaCat. Cartoon-style cat with soft yellow fur and a white flower on its head, standing with a mischievous grin, one paw raised playfully, bright eyes full of energy, cheeky and fun, white background",
    "Sticker VanillaCat. Cartoon-style cat with soft yellow fur and a white flower on its head, jumping mid-air with a surprised expression, wide eyes, and mouth open in excitement, paws stretched out, energetic and playful, forest background.",
]

# Centers the main column and caps its width.
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""


def update_controlnet_visibility(controlnet_enabled):
    """Show or hide the ControlNet scale, mode and image widgets."""
    return (
        gr.update(visible=controlnet_enabled),
        gr.update(visible=controlnet_enabled),
        gr.update(visible=controlnet_enabled),
    )


def update_ip_adapter_visibility(ip_adapter_enabled):
    """Show or hide the IP-Adapter scale and image widgets."""
    return gr.update(visible=ip_adapter_enabled), gr.update(visible=ip_adapter_enabled)


with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # UM235 DIFFUSION Space")
        model_id_input = gr.Dropdown(
            label="Choose Model",
            choices=[
                "stable-diffusion-v1-5/stable-diffusion-v1-5",
                "CompVis/stable-diffusion-v1-4",
                "SD1.5 + lora Unet TextEncoder",
                "SD1.5 + lora Unet",
            ],
            value="SD1.5 + lora Unet TextEncoder",
            show_label=True,
            type="value",
        )
        with gr.Row():
            lscale = gr.Slider(
                label="Lora scale",
                minimum=0,
                maximum=2,
                step=0.05,
                value=1,
            )
        with gr.Row():
            d_bckg = gr.Checkbox(label="Delete Background", value=False)
            # NOTE(review): the two checkboxes below are not passed to `infer`,
            # so toggling them currently has no effect. The intended DDIM setup
            # appears to be:
            # pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, rescale_betas_zero_snr=True)
            ddim_use = gr.Checkbox(label="Enable DDIMScheduler", value=False)
            distill_vae = gr.Checkbox(label="Use tiny VAE with distill model", value=True)
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
        with gr.Accordion("ControlNet Settings", open=False):
            controlnet_enabled = gr.Checkbox(label="Enable ControlNet", value=False)
            with gr.Row():
                control_strength = gr.Slider(
                    label="ControlNet scale",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.75,
                    visible=False,
                )
                control_mode = gr.Dropdown(
                    label="ControlNet Mode",
                    choices=list(CONTROLNET_MODES.keys()),
                    value="Canny Edge Detection",
                    visible=False,
                )
            control_image = gr.Image(label="ControlNet Image", type="pil", visible=False)
        with gr.Accordion("IP-Adapter Settings", open=False):
            ip_adapter_enabled = gr.Checkbox(label="Enable IP-Adapter", value=False)
            with gr.Row():
                ip_adapter_scale = gr.Slider(
                    label="IP-Adapter Scale",
                    minimum=0.0,
                    maximum=2.0,
                    step=0.05,
                    value=1.0,
                    visible=False,
                )
                ip_adapter_image = gr.Image(label="IP-Adapter Image", type="pil", visible=False)
        with gr.Row():
            run_button = gr.Button("Run", scale=0, variant="primary")
        result = gr.Image(label="Result", show_label=False)
        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                visible=True,
                value="worst quality, normal quality, low quality, low res, blurry, distortion, text, watermark, logo, banner, extra digits, cropped, jpeg artifacts,",
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=235,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=7.0,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=36,
                )
        gr.Examples(examples=examples, inputs=[prompt])

    # Inputs map positionally onto `infer`'s parameters, so this order must
    # match its signature. `d_bckg` is now an actual input instead of a global
    # that `infer` tested by object truthiness.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            model_id_input,
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            lscale,
            controlnet_enabled,
            control_strength,
            control_mode,
            control_image,
            ip_adapter_enabled,
            ip_adapter_scale,
            ip_adapter_image,
            d_bckg,
        ],
        outputs=[result, seed],
    )
    controlnet_enabled.change(
        fn=update_controlnet_visibility,
        inputs=[controlnet_enabled],
        outputs=[control_strength, control_mode, control_image],
    )
    ip_adapter_enabled.change(
        fn=update_ip_adapter_visibility,
        inputs=[ip_adapter_enabled],
        outputs=[ip_adapter_scale, ip_adapter_image],
    )

if __name__ == "__main__":
    demo.launch()