import random

import gradio as gr
import numpy as np
import torch
from diffusers import (
    ControlNetModel,
    DDIMScheduler,
    DiffusionPipeline,
    StableDiffusionControlNetPipeline,
)
from peft import PeftModel
from rembg import remove
device = "cuda" if torch.cuda.is_available() else "cpu"
if torch.cuda.is_available():
torch_dtype = torch.float16
else:
torch_dtype = torch.float32
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
# ControlNet v1.1 modes: human-readable alias -> Hugging Face checkpoint id
CONTROLNET_MODES = {
    "Canny Edge Detection": "lllyasviel/control_v11p_sd15_canny",
    "Pixel to Pixel": "lllyasviel/control_v11e_sd15_ip2p",
    "Inpainting": "lllyasviel/control_v11p_sd15_inpaint",
    "Multi-Level Line Segments": "lllyasviel/control_v11p_sd15_mlsd",
    "Depth Estimation": "lllyasviel/control_v11f1p_sd15_depth",
    "Surface Normal Estimation": "lllyasviel/control_v11p_sd15_normalbae",
    "Image Segmentation": "lllyasviel/control_v11p_sd15_seg",
    "Line Art Generation": "lllyasviel/control_v11p_sd15_lineart",
    "Anime Line Art": "lllyasviel/control_v11p_sd15_lineart_anime",
    "Human Pose Estimation": "lllyasviel/control_v11p_sd15_openpose",
    "Scribble-Based Generation": "lllyasviel/control_v11p_sd15_scribble",
    "Soft Edge Generation": "lllyasviel/control_v11p_sd15_softedge",
    "Image Shuffling": "lllyasviel/control_v11e_sd15_shuffle",
    "Image Tiling": "lllyasviel/control_v11f1e_sd15_tile",
}

# @spaces.GPU  # [uncomment to use ZeroGPU]
def infer(
    model_id,
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    lscale=0.0,
    controlnet_enabled=False,
    control_strength=0.0,
    control_mode=None,
    control_image=None,
    ip_adapter_enabled=False,
    ip_adapter_scale=0.0,
    ip_adapter_image=None,
    d_bckg=False,
    ddim_use=False,
    distill_vae=False,
    progress=gr.Progress(track_tqdm=True),
):
    control_strength = float(control_strength)
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)

    # IP-Adapter expects an RGB reference image; guard against a missing upload
    if ip_adapter_enabled and ip_adapter_image is not None:
        ip_adapter_image = ip_adapter_image.convert("RGB").resize((510, 510))

    # Build the pipeline; the LoRA variants load PEFT adapters on top of SD1.5
    if controlnet_enabled and control_image:
        controlnet_model = ControlNetModel.from_pretrained(CONTROLNET_MODES.get(control_mode), torch_dtype=torch_dtype)
        if model_id == "SD1.5 + lora Unet TextEncoder":
            pipe = StableDiffusionControlNetPipeline.from_pretrained(
                "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet_model, torch_dtype=torch_dtype
            )
            pipe.unet = PeftModel.from_pretrained(pipe.unet, "um235/vCat_v2", subfolder="unet")
            pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, "um235/vCat_v2", subfolder="text_encoder")
        elif model_id == "SD1.5 + lora Unet":
            pipe = StableDiffusionControlNetPipeline.from_pretrained(
                "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet_model, torch_dtype=torch_dtype
            )
            pipe.unet = PeftModel.from_pretrained(pipe.unet, "um235/cartoon_cat_stickers")
        else:
            pipe = StableDiffusionControlNetPipeline.from_pretrained(model_id, controlnet=controlnet_model, torch_dtype=torch_dtype)
    else:
        if model_id == "SD1.5 + lora Unet TextEncoder":
            pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch_dtype)
            pipe.unet = PeftModel.from_pretrained(pipe.unet, "um235/vCat_v2", subfolder="unet")
            pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, "um235/vCat_v2", subfolder="text_encoder")
        elif model_id == "SD1.5 + lora Unet":
            pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch_dtype)
            pipe.unet = PeftModel.from_pretrained(pipe.unet, "um235/cartoon_cat_stickers")
        else:
            pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)

    if ip_adapter_enabled and ip_adapter_image is not None:
        pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
        pipe.set_ip_adapter_scale(ip_adapter_scale)

    pipe.safety_checker = None
    if ddim_use:
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, rescale_betas_zero_snr=True)
    pipe = pipe.to(device)
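
    # `distill_vae` arrives from the "Use tiny VAE with distill model" checkbox
    # but is not applied anywhere above. A minimal sketch of wiring it up,
    # assuming the intended model is TAESD ("madebyollin/taesd" is an
    # assumption, not taken from the original code):
    if distill_vae:
        from diffusers import AutoencoderTiny

        pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=pipe.dtype).to(device)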

    # Only pass ControlNet / IP-Adapter arguments when those paths are active;
    # the plain text-to-image pipeline does not accept them
    extra_kwargs = {}
    if controlnet_enabled and control_image:
        extra_kwargs["image"] = control_image
        extra_kwargs["controlnet_conditioning_scale"] = control_strength
    if ip_adapter_enabled and ip_adapter_image is not None:
        extra_kwargs["ip_adapter_image"] = ip_adapter_image

    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
        cross_attention_kwargs={"scale": lscale},  # LoRA strength
        **extra_kwargs,
    ).images[0]

    if d_bckg:
        image = remove(image)  # strip the background with rembg
    return image, seed
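

# A minimal sketch of exercising infer() without the UI (values are
# illustrative, not from the app); handy as a quick smoke test:
#
#   image, used_seed = infer(
#       "SD1.5 + lora Unet",
#       "Sticker VanillaCat. Cartoon-style cat waving, white background",
#       "low quality, blurry",
#       seed=0, randomize_seed=True, width=512, height=512,
#       guidance_scale=7.3, num_inference_steps=30, lscale=0.85,
#   )
#   image.save("vanilla_cat.png")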

examples = [
    "Sticker VanillaCat. Cartoon-style cat with soft yellow fur and one white flower on its head, sitting in lotus pose on a yoga mat, with its paws pressed together in front of its chest in a prayer position, eyes closed, looking calm and peaceful.",
    "Sticker VanillaCat. Cartoon-style cat with soft yellow fur and a white flower on its head, standing with a mischievous grin, one paw raised playfully, bright eyes full of energy, cheeky and fun, white background",
    "Sticker VanillaCat. Cartoon-style cat with soft yellow fur and a white flower on its head, jumping mid-air with a surprised expression, wide eyes, and mouth open in excitement, paws stretched out, energetic and playful, forest background.",
]
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""


def update_controlnet_visibility(controlnet_enabled):
    return (
        gr.update(visible=controlnet_enabled),
        gr.update(visible=controlnet_enabled),
        gr.update(visible=controlnet_enabled),
    )


def update_ip_adapter_visibility(ip_adapter_enabled):
    return gr.update(visible=ip_adapter_enabled), gr.update(visible=ip_adapter_enabled)

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # UM235 DIFFUSION Space")
        model_id_input = gr.Dropdown(
            label="Choose Model",
            choices=[
                "stable-diffusion-v1-5/stable-diffusion-v1-5",
                "CompVis/stable-diffusion-v1-4",
                "SD1.5 + lora Unet TextEncoder",
                "SD1.5 + lora Unet",
            ],
            value="SD1.5 + lora Unet TextEncoder",
            show_label=True,
            type="value",
        )
        with gr.Row():
            lscale = gr.Slider(
                label="Lora scale",
                minimum=0,
                maximum=2,
                step=0.05,
                value=0.85,
            )
        with gr.Row():
            d_bckg = gr.Checkbox(label="Delete Background", value=False)
            ddim_use = gr.Checkbox(label="Enable DDIMScheduler", value=False)
            distill_vae = gr.Checkbox(label="Use tiny VAE with distill model", value=False)
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
with gr.Accordion("ControlNet Settings", open=False):
controlnet_enabled = gr.Checkbox(label="Enable ControlNet", value=False)
with gr.Row():
control_strength = gr.Slider(
label="ControlNet scale",
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.75,
visible=False,
)
control_mode = gr.Dropdown(
label="ControlNet Mode",
choices=list(CONTROLNET_MODES.keys()),
value="Canny Edge Detection",
visible=False,
)
control_image = gr.Image(label="ControlNet Image", type="pil", visible=False)
with gr.Accordion("IP-Adapter Settings", open=False):
ip_adapter_enabled = gr.Checkbox(label="Enable IP-Adapter", value=False)
with gr.Row():
ip_adapter_scale = gr.Slider(
label="IP-Adapter Scale",
minimum=0.0,
maximum=2.0,
step=0.05,
value=0.55,
visible=False,
)
ip_adapter_image = gr.Image(label="IP-Adapter Image", type="pil", visible=False)
        with gr.Row():
            run_button = gr.Button("Run", scale=0, variant="primary")
        result = gr.Image(label="Result", show_label=False)
with gr.Accordion("Advanced Settings", open=False):
negative_prompt = gr.Text(
label="Negative prompt",
max_lines=1,
placeholder="Enter a negative prompt",
visible=True,
value="worst quality,low quality, low res, blurry, distortion, jpeg artifacts, backround"
)
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=1274800826,
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
with gr.Row():
width = gr.Slider(
label="Width",
minimum=256,
maximum=MAX_IMAGE_SIZE,
step=32,
value=512,
)
height = gr.Slider(
label="Height",
minimum=256,
maximum=MAX_IMAGE_SIZE,
step=32,
value=512,
)
with gr.Row():
guidance_scale = gr.Slider(
label="Guidance scale",
minimum=0.0,
maximum=10.0,
step=0.1,
value=7.3,
)
num_inference_steps = gr.Slider(
label="Number of inference steps",
minimum=1,
maximum=50,
step=1,
value=36,
)
gr.Examples(examples=examples, inputs=[prompt])
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            model_id_input,
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            lscale,
            controlnet_enabled,
            control_strength,
            control_mode,
            control_image,
            ip_adapter_enabled,
            ip_adapter_scale,
            ip_adapter_image,
            d_bckg,
            ddim_use,
            distill_vae,
        ],
        outputs=[result, seed],
    )
    controlnet_enabled.change(
        fn=update_controlnet_visibility,
        inputs=[controlnet_enabled],
        outputs=[control_strength, control_mode, control_image],
    )
    ip_adapter_enabled.change(
        fn=update_ip_adapter_visibility,
        inputs=[ip_adapter_enabled],
        outputs=[ip_adapter_scale, ip_adapter_image],
    )

if __name__ == "__main__":
    demo.launch()