Edison's picture
Refactor: Improve error handling and add image loading from URLs in app.py
6a7f883
raw
history blame
2.64 kB
import gradio as gr
import torch
from PIL import Image
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
# Load the base Stable Diffusion checkpoint and attach the Canny ControlNet to it.
model_id = "CompVis/stable-diffusion-v1-4"
controlnet_id = "lllyasviel/sd-controlnet-canny"

# NOTE(review): this ControlNet is paired here with the v1-4 base model;
# confirm that this base/ControlNet combination is the intended one.
_controlnet = ControlNetModel.from_pretrained(controlnet_id)
pipe = StableDiffusionControlNetPipeline.from_pretrained(model_id, controlnet=_controlnet)

# Prefer the GPU when CUDA is available; otherwise run on the CPU.
_device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = pipe.to(_device)
def is_valid_pil_image(image):
    """Report whether *image* is a PIL ``Image.Image`` (or a subclass of it).

    Any other object, including ``None``, yields ``False``.
    """
    accepted_types = (Image.Image,)
    return isinstance(image, accepted_types)
def _run_pipeline_step(step_number, prompt, conditioning_image):
    """Run one ControlNet generation with ``pipe`` and return its first image.

    Raises:
        gr.Error: message prefixed with ``"Error in Step {step_number}: "`` if
            the pipeline call raises, or if its first output is not a PIL image.
    """
    try:
        # StableDiffusionControlNetPipeline receives the ControlNet conditioning
        # image through the ``image`` keyword; ``control_image`` is not one of
        # its parameters, so the old call could never succeed.
        generated = pipe(prompt=prompt, image=conditioning_image).images[0]
    except Exception as e:  # UI boundary: surface any pipeline failure to the user
        raise gr.Error(f"Error in Step {step_number}: {e}") from e
    if not is_valid_pil_image(generated):
        raise gr.Error(
            f"Error in Step {step_number}: Generated image is not a valid PIL image."
        )
    return generated


def generate_image(control_image, additional_prompt, reference_image1, reference_image2):
    """Run the three-step ControlNet generation and return the final image.

    Each step calls the module-level ``pipe`` with ``additional_prompt`` and one
    conditioning image: ``control_image`` (step 1), ``reference_image1``
    (step 2), then ``reference_image2`` (step 3). Only step 3's image is returned.

    Args:
        control_image: Conditioning image for step 1. The loaded ControlNet is
            the Canny one, so this presumably must already be an edge map.
            NOTE(review): no edge detection is performed here; confirm upstream.
        additional_prompt: Text prompt used for all three steps.
        reference_image1: Conditioning image for step 2.
        reference_image2: Conditioning image for step 3.

    Returns:
        The PIL image produced by step 3.

    Raises:
        gr.Error: if any input is not a PIL image, or if a step fails or does
            not produce a PIL image. Raising ``gr.Error`` lets Gradio display the
            message. A message string returned to the ``gr.Image`` output would
            instead be handed to the image component rather than shown as text.
    """
    labelled_inputs = (
        ("Control Image", control_image),
        ("Reference Image 1", reference_image1),
        ("Reference Image 2", reference_image2),
    )
    invalid = [label for label, img in labelled_inputs if not is_valid_pil_image(img)]
    if invalid:
        raise gr.Error(
            "Error: missing or invalid input image(s): " + ", ".join(invalid)
        )

    # NOTE(review): steps 2 and 3 were meant to be "reference_adain+attn"
    # generation guided by the reference images. As written, each step is an
    # independent ControlNet call: the images from steps 1 and 2 are discarded
    # and never fed into later steps. Reference-guided generation would need a
    # pipeline that accepts a reference image. Confirm the intended behaviour.
    _run_pipeline_step(1, additional_prompt, control_image)
    _run_pipeline_step(2, additional_prompt, reference_image1)
    return _run_pipeline_step(3, additional_prompt, reference_image2)
# Gradio UI: a control image, a text prompt and two reference images go in;
# one generated image comes out.
interface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Image(type="pil", label="Control Image"),
        gr.Textbox(label="Additional Prompt"),
        gr.Image(type="pil", label="Reference Image 1"),
        gr.Image(type="pil", label="Reference Image 2"),
    ],
    outputs=gr.Image(type="pil", label="Generated Image"),
    title="Stable Diffusion with Multi-Step ControlNet",
    description="Generate images using a multi-step process with Stable Diffusion and ControlNet.",
)


if __name__ == "__main__":
    # Start the web app only when executed as a script, not when imported.
    interface.launch()