34 / app.py
tkkkkk's picture
Update app.py
a396e9b verified
import torch
from diffusers import StableDiffusionPipeline
import gradio as gr
# 选择一个文生图模型
MODEL_NAME = "runwayml/stable-diffusion-v1-5"
def load_model():
"""
自动下载并加载文生图模型
"""
try:
# 使用 torch.float16 减少显存占用
pipe = StableDiffusionPipeline.from_pretrained(
MODEL_NAME,
torch_dtype=torch.float16
)
# 如果有GPU,移动模型到GPU
if torch.cuda.is_available():
pipe = pipe.to("cpu")
return pipe
except Exception as e:
print(f"模型加载错误: {e}")
return None
# 全局模型变量
generation_pipe = load_model()
def generate_image(prompt, negative_prompt="", steps=50, guidance_scale=7.5):
"""
根据文本提示生成图像
:param prompt: 图像生成提示词
:param negative_prompt: 负面提示词
:param steps: 生成步数
:param guidance_scale: 引导强度
:return: 生成的图像
"""
if generation_pipe is None:
return "模型加载失败"
try:
# 生成图像
image = generation_pipe(
prompt,
negative_prompt=negative_prompt,
num_inference_steps=steps,
guidance_scale=guidance_scale
).images[0]
return image
except Exception as e:
print(f"图像生成错误: {e}")
return None
# 创建 Gradio 界面
def create_interface():
iface = gr.Interface(
fn=generate_image,
inputs=[
gr.Textbox(label="图像描述 (Prompt)"),
gr.Textbox(label="负面描述 (Negative Prompt)", value=""),
gr.Slider(minimum=10, maximum=100, value=50, label="生成步数"),
gr.Slider(minimum=1, maximum=15, value=7.5, label="引导强度")
],
outputs=gr.Image(label="生成的图像"),
title="文生图生成",
description="使用 Stable Diffusion 从文本生成图像"
)
return iface
# 启动应用
if __name__ == "__main__":
# 检查并提示模型加载状态
if generation_pipe is None:
print("警告:模型加载失败,应用可能无法正常工作")
# 启动 Gradio 界面
interface = create_interface()
interface.launch(
# 如果需要公开访问,取消注释下面的行
# share=True
)