"""Gradio demo: run a Stable Diffusion + ControlNet (canny) pipeline in three passes."""

import gradio as gr
import torch
from PIL import Image
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel

# Load the base model and the ControlNet once at import time so that every
# request served by the Gradio app reuses the same pipeline instance.
model_id = "CompVis/stable-diffusion-v1-4"
controlnet_id = "lllyasviel/sd-controlnet-canny"
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    model_id,
    controlnet=ControlNetModel.from_pretrained(controlnet_id),
)
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")


def is_valid_pil_image(image) -> bool:
    """Return True if *image* is a ``PIL.Image.Image`` instance."""
    return isinstance(image, Image.Image)


def _run_step(step_number: int, prompt: str, conditioning_image: Image.Image) -> Image.Image:
    """Run one pipeline pass and return its first generated image.

    Any failure is re-raised as ``gr.Error`` prefixed with the step number so
    the message is shown in the Gradio UI instead of being handed to the image
    output component as a string.
    """
    try:
        # The ControlNet text-to-image pipeline takes its conditioning image
        # through the ``image`` keyword.
        result = pipe(prompt=prompt, image=conditioning_image).images[0]
    except Exception as e:
        raise gr.Error(f"Error in Step {step_number}: {str(e)}") from e

    if not is_valid_pil_image(result):
        raise gr.Error(
            f"Error in Step {step_number}: Generated image is not a valid PIL image."
        )
    return result


def generate_image(control_image, additional_prompt, reference_image1, reference_image2):
    """Generate an image by running the pipeline once per input image.

    Args:
        control_image: PIL image used as the conditioning image for step 1.
        additional_prompt: Text prompt passed to every step.
        reference_image1: PIL image used as the conditioning image for step 2.
        reference_image2: PIL image used as the conditioning image for step 3.

    Returns:
        The PIL image produced by step 3.

    Raises:
        gr.Error: If any input is not a PIL image (for example, an image slot
            was left empty), or if any step fails or yields a non-PIL result.
    """
    # Validate all inputs up front, before spending time on inference.
    input_images = [control_image, reference_image1, reference_image2]
    if not all(is_valid_pil_image(img) for img in input_images):
        raise gr.Error("Error: One or more input images are not valid PIL images.")

    # NOTE(review): the outputs of steps 1 and 2 are not fed into later steps,
    # so only step 3 influences the returned image. The original comments
    # mention "reference_adain+attn", which this pipeline call does not
    # implement — confirm the intended chaining before relying on this.
    # Step 1: generate from the control image.
    _run_step(1, additional_prompt, control_image)
    # Step 2: generate from the first reference image.
    _run_step(2, additional_prompt, reference_image1)
    # Step 3: generate from the second reference image; this is the result.
    return _run_step(3, additional_prompt, reference_image2)


# Define the Gradio interface.
interface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.components.Image(type="pil", label="Control Image"),
        gr.components.Textbox(label="Additional Prompt"),
        gr.components.Image(type="pil", label="Reference Image 1"),
        gr.components.Image(type="pil", label="Reference Image 2"),
    ],
    outputs=gr.components.Image(type="pil", label="Generated Image"),
    title="Stable Diffusion with Multi-Step ControlNet",
    description="Generate images using a multi-step process with Stable Diffusion and ControlNet.",
)

if __name__ == "__main__":
    interface.launch()