Files changed (1) hide show
  1. app.py +83 -0
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from diffusers import StableDiffusionPipeline
4
+ from PIL import Image
5
+ import traceback
6
+ from typing import Optional
7
+
8
+ # Stable Diffusion模型相关设置
9
+ model_id: str = "runwayml/stable-diffusion-v1-5"
10
+ device: str = "cpu" # force CPU usage for compatibility
11
+
12
+ image_generator_pipe: Optional[StableDiffusionPipeline] = None
13
+
14
+ try:
15
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
16
+ image_generator_pipe = pipe.to(device)
17
+ except Exception as e:
18
+ print(f"Failed to load Stable Diffusion model: {e}")
19
+
20
+ # 提示词优化函数(简单版)
21
+ def optimize_prompt_simple(short_description: str) -> str:
22
+ optimized_prompt = f"Generate a high-quality, detailed image based on the following description: {short_description}"
23
+ return optimized_prompt
24
+
25
+ # 图像生成函数
26
+ def generate_image_sd(short_description: str,
27
+ negative_prompt: str,
28
+ guidance_scale: float,
29
+ num_inference_steps: int) -> Image.Image:
30
+ optimized_prompt = optimize_prompt_simple(short_description)
31
+
32
+ try:
33
+ with torch.no_grad():
34
+ if image_generator_pipe is None:
35
+ raise RuntimeError("Stable Diffusion pipeline is not available.")
36
+
37
+ output = image_generator_pipe(
38
+ prompt=optimized_prompt,
39
+ negative_prompt=negative_prompt,
40
+ guidance_scale=guidance_scale,
41
+ num_inference_steps=num_inference_steps
42
+ )
43
+ image = output.images[0] if output.images else None
44
+
45
+ if not image:
46
+ raise RuntimeError("No image was returned from the generation pipeline.")
47
+
48
+ return image
49
+ except Exception as e:
50
+ traceback.print_exc()
51
+ raise gr.Error(f"Image generation failed: {str(e)}")
52
+ # Gradio界面
53
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
54
+ with gr.Row():
55
+ with gr.Column(scale=1):
56
+ short_description = gr.Textbox(label="Short Description", placeholder="A magical treehouse in the sky")
57
+ optimized_prompt_display = gr.Textbox(label="Optimized Prompt", interactive=False)
58
+ neg_prompt = gr.Textbox(label="Negative Prompt", placeholder="blurry, distorted, watermark")
59
+ guidance = gr.Slider(1.0, 15.0, value=7.5, step=0.5, label="Guidance Scale")
60
+ steps = gr.Slider(10, 50, value=25, step=1, label="Inference Steps")
61
+ generate_btn = gr.Button("Generate Image")
62
+
63
+ with gr.Column(scale=1):
64
+ output_image = gr.Image(label="Generated Image", type="pil")
65
+
66
+ # 当用户输入简短描述时,自动优化提示词并显示
67
+ short_description.input(
68
+ fn=lambda x: optimize_prompt_simple(x),
69
+ inputs=short_description,
70
+ outputs=optimized_prompt_display
71
+ )
72
+
73
+ generate_btn.click(
74
+ fn=generate_image_sd,
75
+ inputs=[short_description, neg_prompt, guidance, steps],
76
+ outputs=output_image
77
+ )
78
+
79
+ if __name__ == "__main__":
80
+ if not image_generator_pipe:
81
+ print("WARNING: Stable Diffusion pipeline is not available. UI will launch, but generation will fail.")
82
+
83
+ demo.launch(server_name="0.0.0.0", server_port=7860)