TextTo3DScene / app.py
CantonMonkey
test ControlNet with a random generated dummy img
b49d3bd
# import torch
# from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline
# from PIL import Image
# import gradio as gr
# # Automatically select the device
# device = "cuda" if torch.cuda.is_available() else "cpu"
# print(f"Using device: {device}")
# # Load the ControlNet model
# controlnet = ControlNetModel.from_pretrained(
# "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float32
# )
# # Load Stable Diffusion + ControlNet
# pipe = StableDiffusionControlNetPipeline.from_pretrained(
# "CompVis/stable-diffusion-v1-4",
# controlnet=controlnet,
# torch_dtype=torch.float32
# ).to(device)
# # Reduce memory usage when running on CPU
# pipe.enable_attention_slicing()
# # Text-to-image generation function
# def generate_image(prompt, num_steps=20, height=256, width=256):
# """
# prompt: str, text description
# num_steps: int, number of inference steps (can be lower on CPU)
# height, width: int, output image resolution
# """
# image = pipe(prompt, num_inference_steps=num_steps, height=height, width=width).images[0]
# return image
# # Gradio interface
# interface = gr.Interface(
# fn=generate_image,
# inputs=[
# gr.Textbox(label="Prompt", placeholder="Enter text prompt here..."),
# gr.Slider(5, 50, value=20, step=1, label="Inference Steps"),
# gr.Slider(128, 512, value=256, step=64, label="Height"),
# gr.Slider(128, 512, value=256, step=64, label="Width"),
# ],
# outputs=gr.Image(type="pil"),
# title="Text2Image Demo (v1-4 + ControlNet)",
# description="Generate images from text using Stable Diffusion v1-4 + ControlNet (CPU/GPU compatible)"
# )
# interface.launch()
import torch
from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline
from PIL import Image
import gradio as gr
import numpy as np
# Pick the compute device: CUDA when torch reports it available, CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Load the ControlNet model (canny checkpoint) in float32.
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny",
    torch_dtype=torch.float32,
)

# Load Stable Diffusion v1-4 with the ControlNet attached, then move the
# assembled pipeline onto the selected device.
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    controlnet=controlnet,
    torch_dtype=torch.float32,
)
pipe = pipe.to(device)

# Attention slicing is enabled to lower memory use (the original note
# mentions CPU runs); it is applied regardless of the selected device.
pipe.enable_attention_slicing()
# Text-to-image generation function
def generate_image(prompt, num_steps=20, height=256, width=256):
    """Generate one image from a text prompt using the module-level ``pipe``.

    An all-black RGB placeholder of the requested size is passed as the
    pipeline's ``image`` argument (the dummy ControlNet input).

    Args:
        prompt: Text description of the image to generate.
        num_steps: Number of inference steps forwarded to the pipeline.
        height: Requested output height in pixels.
        width: Requested output width in pixels.

    Returns:
        The first entry of the pipeline result's ``images`` list.
    """
    # UI sliders may hand their values over as floats; array shapes and the
    # step count must be integers, so normalise them up front.
    num_steps = int(num_steps)
    height = int(height)
    width = int(width)

    # NumPy image arrays are (rows, cols, channels), i.e. (height, width, 3).
    dummy_image = Image.fromarray(np.zeros((height, width, 3), dtype=np.uint8))

    result = pipe(
        prompt,
        num_inference_steps=num_steps,
        height=height,
        width=width,
        image=dummy_image,
    )
    return result.images[0]
# Gradio interface: one prompt box plus three sliders, supplied to
# generate_image in the same order as its parameters.
interface = gr.Interface(
    title="Text2Image Demo (v1-4 + ControlNet)",
    description="Generate images from text using Stable Diffusion v1-4 + ControlNet",
    fn=generate_image,
    inputs=[
        gr.Textbox(placeholder="Enter text prompt here...", label="Prompt"),
        gr.Slider(minimum=5, maximum=50, step=1, value=20, label="Inference Steps"),
        gr.Slider(minimum=128, maximum=512, step=64, value=256, label="Height"),
        gr.Slider(minimum=128, maximum=512, step=64, value=256, label="Width"),
    ],
    outputs=gr.Image(type="pil"),
)

# Start the web UI as soon as the script is executed.
interface.launch()