# Hugging Face Space status: Sleeping
import gradio as gr | |
import torch | |
from PIL import Image | |
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel | |
# Load the canny-edge ControlNet and the Stable Diffusion v1.4 base model into a
# single ControlNet pipeline, then move it to the GPU when CUDA is available
# (CPU otherwise).
model_id = "CompVis/stable-diffusion-v1-4"
controlnet_id = "lllyasviel/sd-controlnet-canny"

controlnet_model = ControlNetModel.from_pretrained(controlnet_id)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    model_id,
    controlnet=controlnet_model,
)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
pipe = pipe.to(device)
def is_valid_pil_image(image):
    """Report whether *image* is an instance of PIL's ``Image.Image`` class.

    Args:
        image: Any object (e.g. a Gradio input value or pipeline output).

    Returns:
        True if *image* is a ``PIL.Image.Image`` (or subclass) instance,
        False for anything else, including ``None``.
    """
    pil_image_class = Image.Image
    return isinstance(image, pil_image_class)
def _run_controlnet_step(step, prompt, conditioning_image):
    """Run one pass of the ControlNet pipeline and return its first image.

    Args:
        step: 1-based step number; used only to label error messages.
        prompt: Text prompt forwarded to the pipeline.
        conditioning_image: PIL image given to the ControlNet as its
            conditioning input.

    Returns:
        The first image of the pipeline output.

    Raises:
        gr.Error: If the pipeline call raises, or its output is not a PIL image.
    """
    try:
        # StableDiffusionControlNetPipeline takes the conditioning image as
        # ``image=``; ``control_image=`` is the keyword of the img2img/inpaint
        # ControlNet pipelines, not of this text-to-image one.
        generated = pipe(prompt=prompt, image=conditioning_image).images[0]
    except Exception as exc:
        # UI boundary: surface any pipeline failure to the user, keep the cause.
        raise gr.Error(f"Error in Step {step}: {exc}") from exc
    if not is_valid_pil_image(generated):
        raise gr.Error(f"Error in Step {step}: Generated image is not a valid PIL image.")
    return generated


def generate_image(control_image, additional_prompt, reference_image1, reference_image2):
    """Run three ControlNet passes for the Gradio form and return the last image.

    Args:
        control_image: Conditioning image for step 1 (canny ControlNet).
        additional_prompt: Text prompt used for all three passes.
        reference_image1: Conditioning image for step 2.
        reference_image2: Conditioning image for step 3.

    Returns:
        The PIL image generated in step 3.

    Raises:
        gr.Error: If any input is not a PIL image or any step fails. The
            message is raised rather than returned because this function
            feeds an image output component, which cannot display a string.
    """
    # Validate all inputs up front so no pipeline pass runs on bad input.
    if not all(
        is_valid_pil_image(img)
        for img in (control_image, reference_image1, reference_image2)
    ):
        raise gr.Error("Error: One or more input images are not valid PIL images.")

    # NOTE(review): the outputs of steps 1 and 2 are discarded. Steps 2-3 were
    # described as "reference_adain+attn" generation, but they simply re-run
    # the same canny ControlNet with the reference image as conditioning input.
    # So, apart from sampling randomness, the returned image depends only on
    # `additional_prompt` and `reference_image2`.
    # TODO: chain the steps or switch to a reference-only ControlNet pipeline.
    # The canny ControlNet presumably expects Canny edge maps, but the
    # uploaded images are passed through unprocessed; confirm what users supply.
    _run_controlnet_step(1, additional_prompt, control_image)
    _run_controlnet_step(2, additional_prompt, reference_image1)
    return _run_controlnet_step(3, additional_prompt, reference_image2)
# Define the Gradio interface: a conditioning image, a text prompt and two
# reference images go in; a single generated image comes out.
input_components = [
    gr.components.Image(type="pil", label="Control Image"),
    gr.components.Textbox(label="Additional Prompt"),
    gr.components.Image(type="pil", label="Reference Image 1"),
    gr.components.Image(type="pil", label="Reference Image 2"),
]
output_component = gr.components.Image(type="pil", label="Generated Image")

interface = gr.Interface(
    fn=generate_image,
    inputs=input_components,
    outputs=output_component,
    title="Stable Diffusion with Multi-Step ControlNet",
    description="Generate images using a multi-step process with Stable Diffusion and ControlNet.",
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()