alexander00001 commited on
Commit
a77e2f8
·
verified ·
1 Parent(s): f9d544d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +352 -112
app.py CHANGED
@@ -1,154 +1,394 @@
1
  import gradio as gr
2
- import numpy as np
3
- import random
4
-
5
- # import spaces #[uncomment to use ZeroGPU]
6
- from diffusers import DiffusionPipeline
7
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
11
-
12
- if torch.cuda.is_available():
13
- torch_dtype = torch.float16
14
- else:
15
- torch_dtype = torch.float32
16
-
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe = pipe.to(device)
19
-
20
- MAX_SEED = np.iinfo(np.int32).max
21
- MAX_IMAGE_SIZE = 1024
22
-
23
-
24
- # @spaces.GPU #[uncomment to use ZeroGPU]
25
- def infer(
26
- prompt,
27
- negative_prompt,
28
- seed,
29
- randomize_seed,
30
- width,
31
- height,
32
- guidance_scale,
33
- num_inference_steps,
34
- progress=gr.Progress(track_tqdm=True),
 
 
 
 
 
 
35
  ):
36
- if randomize_seed:
37
- seed = random.randint(0, MAX_SEED)
 
 
 
 
 
 
 
38
 
39
- generator = torch.Generator().manual_seed(seed)
 
 
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  image = pipe(
42
- prompt=prompt,
43
- negative_prompt=negative_prompt,
44
- guidance_scale=guidance_scale,
45
- num_inference_steps=num_inference_steps,
46
  width=width,
47
  height=height,
48
  generator=generator,
49
  ).images[0]
50
 
51
- return image, seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
 
 
 
53
 
54
- examples = [
55
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
56
- "An astronaut riding a green horse",
57
- "A delicious ceviche cheesecake slice",
58
- ]
59
 
60
- css = """
61
- #col-container {
62
- margin: 0 auto;
63
- max-width: 640px;
64
- }
65
- """
66
-
67
- with gr.Blocks(css=css) as demo:
68
- with gr.Column(elem_id="col-container"):
69
- gr.Markdown(" # Text-to-Image Gradio Template")
70
-
71
- with gr.Row():
72
- prompt = gr.Text(
73
- label="Prompt",
74
- show_label=False,
75
- max_lines=1,
76
- placeholder="Enter your prompt",
77
- container=False,
78
- )
79
 
80
- run_button = gr.Button("Run", scale=0, variant="primary")
 
 
 
 
 
 
 
 
81
 
82
- result = gr.Image(label="Result", show_label=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
- with gr.Accordion("Advanced Settings", open=False):
85
- negative_prompt = gr.Text(
86
- label="Negative prompt",
87
- max_lines=1,
88
- placeholder="Enter a negative prompt",
89
- visible=False,
 
 
 
90
  )
91
 
92
- seed = gr.Slider(
93
- label="Seed",
94
- minimum=0,
95
- maximum=MAX_SEED,
96
- step=1,
97
- value=0,
 
98
  )
99
 
100
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
 
 
 
 
 
 
101
 
 
102
  with gr.Row():
103
- width = gr.Slider(
104
- label="Width",
105
- minimum=256,
106
- maximum=MAX_IMAGE_SIZE,
107
- step=32,
108
- value=1024, # Replace with defaults that work for your model
109
  )
 
110
 
111
- height = gr.Slider(
112
- label="Height",
113
- minimum=256,
114
- maximum=MAX_IMAGE_SIZE,
115
- step=32,
116
- value=1024, # Replace with defaults that work for your model
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  )
 
 
 
 
 
 
 
 
 
118
 
 
119
  with gr.Row():
120
- guidance_scale = gr.Slider(
121
- label="Guidance scale",
122
  minimum=0.0,
123
- maximum=10.0,
124
- step=0.1,
125
- value=0.0, # Replace with defaults that work for your model
 
126
  )
 
127
 
128
- num_inference_steps = gr.Slider(
129
- label="Number of inference steps",
130
- minimum=1,
131
- maximum=50,
 
132
  step=1,
133
- value=2, # Replace with defaults that work for your model
 
 
 
 
 
 
 
 
134
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
- gr.Examples(examples=examples, inputs=[prompt])
137
- gr.on(
138
- triggers=[run_button.click, prompt.submit],
139
- fn=infer,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  inputs=[
141
- prompt,
142
- negative_prompt,
143
- seed,
144
- randomize_seed,
145
- width,
146
- height,
147
- guidance_scale,
148
- num_inference_steps,
149
  ],
150
- outputs=[result, seed],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  )
152
 
 
 
 
153
  if __name__ == "__main__":
154
  demo.launch()
 
 
1
  import gradio as gr
 
 
 
 
 
2
  import torch
3
+ from diffusers import StableDiffusionPipeline
4
+ from PIL import Image
5
+ import datetime
6
+ import io
7
+ import json
8
+ import os
9
+ from typing import Optional
10
+
11
+ # ======================
12
+ # 配置区(你只需修改这里即可扩展)
13
+ # ======================
14
+
15
+ # 1. 基础模型
16
+ BASE_MODEL = "SG161222/RealisticVisionV6.0"
17
+
18
+ # 2. 固定LoRA(不可选,自动加载)
19
+ FIXED_LORAS = [
20
+ ("Lykon/epiCRealism_LoRA", 0.8), # 质量增强
21
+ ("latent-consistency/lora-dreamshaper", 0.7), # 姿势控制
22
+ ]
23
+
24
+ # 3. 风格模板(自动拼接到用户提示词前)
25
+ STYLE_PROMPTS = {
26
+ "None": "",
27
+ "Realistic": "photorealistic, ultra-detailed skin, natural lighting, 8k, professional photography, f/1.8, shallow depth of field, Canon EOS R5, ",
28
+ "Anime": "anime style, cel shading, vibrant colors, detailed eyes, studio ghibli, trending on pixiv, ",
29
+ "Comic": "comic book style, bold outlines, dynamic angles, comic panel, Marvel style, inked lines, ",
30
+ "Watercolor": "watercolor painting, soft brush strokes, translucent layers, artistic, painterly, paper texture, ",
31
+ }
32
 
33
+ # 4. 可选LoRA下拉菜单(用户可选1个,None表示清除)
34
+ OPTIONAL_LORAS = [
35
+ "None",
36
+ "Add Detail: https://huggingface.co/latent-consistency/lora-add-detail",
37
+ "Vintage Photo: https://huggingface.co/ckpt/LoRA-vintage-photo",
38
+ "Cinematic: https://huggingface.co/latent-consistency/lora-cinematic",
39
+ "Portrait Enhancer: https://huggingface.co/deforum/Portrait-Enhancer-LoRA",
40
+ "Soft Focus: https://huggingface.co/latent-consistency/lora-soft-focus",
41
+ ]
42
+
43
+ # 解析可选LoRA的名称和ID
44
+ OPTIONAL_LORA_MAP = {}
45
+ for item in OPTIONAL_LORAS:
46
+ if item != "None":
47
+ name, url = item.split(": ", 1)
48
+ OPTIONAL_LORA_MAP[name] = url
49
+ else:
50
+ OPTIONAL_LORA_MAP["None"] = None
51
+
52
+ # 默认参数
53
+ DEFAULT_SEED = -1
54
+ DEFAULT_WIDTH = 1024
55
+ DEFAULT_HEIGHT = 1024
56
+ DEFAULT_LORA_SCALE = 0.8
57
+ DEFAULT_STEPS = 30
58
+ DEFAULT_CFG = 7.5
59
+
60
+ # ======================
61
+ # 全局变量:延迟加载模型
62
+ # ======================
63
+ pipe = None
64
  device = "cuda" if torch.cuda.is_available() else "cpu"
65
+
66
+ def load_pipeline():
67
+ global pipe
68
+ if pipe is None:
69
+ print("🚀 Loading base model...")
70
+ pipe = StableDiffusionPipeline.from_pretrained(
71
+ BASE_MODEL,
72
+ torch_dtype=torch.float16,
73
+ safety_checker=None,
74
+ requires_safety_checker=False,
75
+ ).to(device)
76
+ pipe.enable_attention_slicing()
77
+ pipe.enable_vae_slicing()
78
+ pipe.enable_model_cpu_offload() # 适配ZeroGPU
79
+ print("✅ Base model loaded.")
80
+ return pipe
81
+
82
+ def unload_pipeline():
83
+ global pipe
84
+ if pipe is not None:
85
+ del pipe
86
+ torch.cuda.empty_cache()
87
+ pipe = None
88
+ print("🗑️ Pipeline unloaded.")
89
+
90
+ # ======================
91
+ # 主生成函数
92
+ # ======================
93
+ def generate_image(
94
+ prompt, negative_prompt, style, seed, width, height, optional_lora_name, lora_scale,
95
+ steps, cfg_scale
96
  ):
97
+ global pipe
98
+
99
+ # 加载模型(懒加载)
100
+ pipe = load_pipeline()
101
+
102
+ # 处理种子
103
+ if seed == -1:
104
+ seed = torch.randint(0, 2**32, (1,)).item()
105
+ generator = torch.Generator(device=device).manual_seed(seed)
106
 
107
+ # 拼接风格提示词
108
+ full_prompt = STYLE_PROMPTS[style] + prompt
109
+ full_negative_prompt = negative_prompt
110
 
111
+ # 加载固定LoRA(每次生成前都加载,确保状态正确)
112
+ for lora_id, scale in FIXED_LORAS:
113
+ pipe.load_lora_weights(lora_id, adapter_name=lora_id)
114
+ pipe.set_adapters([lora_id], adapter_weights=[scale])
115
+
116
+ # 加载可选LoRA(如果非None)
117
+ if optional_lora_name != "None":
118
+ lora_url = OPTIONAL_LORA_MAP[optional_lora_name]
119
+ pipe.load_lora_weights(lora_url, adapter_name=optional_lora_name)
120
+ pipe.set_adapters([lora_id for lora_id, _ in FIXED_LORAS] + [optional_lora_name],
121
+ adapter_weights=[scale for _, scale in FIXED_LORAS] + [lora_scale])
122
+ else:
123
+ # 清除所有可选LoRA,只保留固定
124
+ pipe.set_adapters([lora_id for lora_id, _ in FIXED_LORAS],
125
+ adapter_weights=[scale for _, scale in FIXED_LORAS])
126
+
127
+ # 生成图像
128
  image = pipe(
129
+ prompt=full_prompt,
130
+ negative_prompt=full_negative_prompt,
131
+ num_inference_steps=steps,
132
+ guidance_scale=cfg_scale,
133
  width=width,
134
  height=height,
135
  generator=generator,
136
  ).images[0]
137
 
138
+ # 生成元数据
139
+ metadata = {
140
+ "prompt": full_prompt,
141
+ "negative_prompt": full_negative_prompt,
142
+ "base_model": BASE_MODEL,
143
+ "fixed_loras": [lora_id for lora_id, _ in FIXED_LORAS],
144
+ "optional_lora": optional_lora_name if optional_lora_name != "None" else None,
145
+ "lora_scale": lora_scale,
146
+ "seed": seed,
147
+ "steps": steps,
148
+ "cfg_scale": cfg_scale,
149
+ "style": style,
150
+ "width": width,
151
+ "height": height,
152
+ "timestamp": datetime.datetime.now().isoformat()
153
+ }
154
 
155
+ # 生成文件名
156
+ timestamp = datetime.datetime.now().strftime("%y%m%d%H%M")
157
+ filename_base = f"{seed}-{timestamp}"
158
 
159
+ # 保存为WebP(高质量)
160
+ img_buffer = io.BytesIO()
161
+ image.save(img_buffer, format="WEBP", quality=95, method=6)
162
+ img_buffer.seek(0)
 
163
 
164
+ # 保存元数据为TXT
165
+ metadata_buffer = io.StringIO()
166
+ json.dump(metadata, metadata_buffer, indent=2, ensure_ascii=False)
167
+ metadata_buffer.seek(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
+ # 返回:图像、元数据、文件名
170
+ return (
171
+ image,
172
+ json.dumps(metadata, indent=2, ensure_ascii=False),
173
+ f"{filename_base}.webp",
174
+ f"{filename_base}.txt",
175
+ img_buffer.getvalue(),
176
+ metadata_buffer.getvalue().encode('utf-8')
177
+ )
178
 
179
+ # ======================
180
+ # Gradio UI
181
+ # ======================
182
+ with gr.Blocks(
183
+ theme=gr.themes.Soft(
184
+ primary_hue="blue",
185
+ secondary_hue="green",
186
+ neutral_hue="slate",
187
+ ).set(
188
+ body_background_fill="linear-gradient(135deg, #1e40af, #059669)",
189
+ button_primary_background_fill="white",
190
+ button_primary_text_color="#1e40af",
191
+ input_background_fill="rgba(255,255,255,0.9)",
192
+ text_size="lg",
193
+ ),
194
+ css="""
195
+ body { font-family: 'Helvetica Neue', 'Segoe UI', 'Arial', sans-serif; }
196
+ .gr-button { font-family: 'Helvetica Neue', 'Arial', sans-serif; font-weight: 500; }
197
+ .gr-textarea { font-family: 'Consolas', 'Monaco', 'Courier New', monospace; }
198
+ """,
199
+ ) as demo:
200
+ gr.Markdown(
201
+ """
202
+ # 🎨 AI Photo Generator (RealisticVision + LoRA)
203
+ **PRO + ZeroGPU Optimized | Multi-LoRA | Style Templates | Metadata Export**
204
+ """
205
+ )
206
 
207
+ with gr.Row():
208
+ with gr.Column(scale=3):
209
+ # a. 提示词输入框
210
+ prompt_input = gr.Textbox(
211
+ label="Prompt (Positive)",
212
+ placeholder="A beautiful woman, golden hour, soft sunlight...",
213
+ lines=5,
214
+ max_lines=20,
215
+ elem_classes=["gr-textarea"]
216
  )
217
 
218
+ # b. 负提示词输入框
219
+ negative_prompt_input = gr.Textbox(
220
+ label="Negative Prompt",
221
+ placeholder="blurry, low quality, deformed, cartoon, anime, text, watermark...",
222
+ lines=5,
223
+ max_lines=20,
224
+ elem_classes=["gr-textarea"]
225
  )
226
 
227
+ # c. 风格选择(单选)
228
+ style_radio = gr.Radio(
229
+ choices=list(STYLE_PROMPTS.keys()),
230
+ label="Style",
231
+ value="Realistic",
232
+ elem_classes=["gr-radio"]
233
+ )
234
 
235
+ # d. 种子选择
236
  with gr.Row():
237
+ seed_input = gr.Slider(
238
+ minimum=-1,
239
+ maximum=99999999,
240
+ step=1,
241
+ value=DEFAULT_SEED,
242
+ label="Seed (-1 = Random)"
243
  )
244
+ seed_reset = gr.Button("Reset Seed")
245
 
246
+ # e. 宽度选择
247
+ with gr.Row():
248
+ width_input = gr.Slider(
249
+ minimum=512,
250
+ maximum=1536,
251
+ step=64,
252
+ value=DEFAULT_WIDTH,
253
+ label="Width"
254
+ )
255
+ width_reset = gr.Button("Reset Width")
256
+
257
+ # f. 高度选择
258
+ with gr.Row():
259
+ height_input = gr.Slider(
260
+ minimum=512,
261
+ maximum=1536,
262
+ step=64,
263
+ value=DEFAULT_HEIGHT,
264
+ label="Height"
265
  )
266
+ height_reset = gr.Button("Reset Height")
267
+
268
+ # g. LoRA选择(下拉)
269
+ optional_lora_dropdown = gr.Dropdown(
270
+ choices=list(OPTIONAL_LORA_MAP.keys()),
271
+ label="Optional LoRA",
272
+ value="None",
273
+ elem_classes=["gr-dropdown"]
274
+ )
275
 
276
+ # h. LoRA控制
277
  with gr.Row():
278
+ lora_scale_slider = gr.Slider(
 
279
  minimum=0.0,
280
+ maximum=1.5,
281
+ step=0.05,
282
+ value=DEFAULT_LORA_SCALE,
283
+ label="LoRA Scale"
284
  )
285
+ lora_reset = gr.Button("Reset LoRA Scale")
286
 
287
+ # i. 功能控制(Steps & CFG)
288
+ with gr.Row():
289
+ steps_slider = gr.Slider(
290
+ minimum=10,
291
+ maximum=100,
292
  step=1,
293
+ value=DEFAULT_STEPS,
294
+ label="Steps"
295
+ )
296
+ cfg_slider = gr.Slider(
297
+ minimum=1.0,
298
+ maximum=20.0,
299
+ step=0.5,
300
+ value=DEFAULT_CFG,
301
+ label="CFG Scale"
302
  )
303
+ gen_reset = gr.Button("Reset Generation")
304
+
305
+ # m. 生成按钮
306
+ generate_btn = gr.Button("✨ Generate Image", variant="primary", size="lg")
307
+
308
+ with gr.Column(scale=2):
309
+ # j. 图片显示区
310
+ image_output = gr.Image(label="Generated Image", height=768, format="webp")
311
+
312
+ # k. 元数据显示区
313
+ metadata_output = gr.Textbox(
314
+ label="Metadata (JSON)",
315
+ lines=12,
316
+ max_lines=20,
317
+ elem_classes=["gr-textarea"]
318
+ )
319
 
320
+ # l. 下载按钮(并列)
321
+ with gr.Row():
322
+ download_img_btn = gr.Button("⬇️ Download Image (WebP)")
323
+ download_meta_btn = gr.Button("⬇️ Download Metadata (TXT)")
324
+
325
+ # 隐藏文件输出(用于下载)
326
+ hidden_img_file = gr.File(visible=False)
327
+ hidden_meta_file = gr.File(visible=False)
328
+
329
+ # ======================
330
+ # 事件绑定
331
+ # ======================
332
+
333
+ # 重置种子
334
+ seed_reset.click(fn=lambda: -1, outputs=seed_input)
335
+ # 重置宽度
336
+ width_reset.click(fn=lambda: DEFAULT_WIDTH, outputs=width_input)
337
+ # 重置高度
338
+ height_reset.click(fn=lambda: DEFAULT_HEIGHT, outputs=height_input)
339
+ # 重置LoRA缩放
340
+ lora_reset.click(fn=lambda: DEFAULT_LORA_SCALE, outputs=lora_scale_slider)
341
+ # 重置生成参数
342
+ gen_reset.click(
343
+ fn=lambda: (DEFAULT_STEPS, DEFAULT_CFG),
344
+ outputs=[steps_slider, cfg_slider]
345
+ )
346
+
347
+ # 生成
348
+ generate_btn.click(
349
+ fn=generate_image,
350
  inputs=[
351
+ prompt_input, negative_prompt_input, style_radio,
352
+ seed_input, width_input, height_input,
353
+ optional_lora_dropdown, lora_scale_slider,
354
+ steps_slider, cfg_slider
 
 
 
 
355
  ],
356
+ outputs=[
357
+ image_output, metadata_output,
358
+ hidden_img_file, hidden_meta_file,
359
+ hidden_img_file, hidden_meta_file
360
+ ]
361
+ )
362
+
363
+ # 下载图片
364
+ download_img_btn.click(
365
+ fn=None,
366
+ inputs=[hidden_img_file],
367
+ outputs=None,
368
+ js="(f) => { const a = document.createElement('a'); a.href = f; a.download = f.split('/').pop(); document.body.appendChild(a); a.click(); document.body.removeChild(a); }"
369
+ )
370
+
371
+ # 下载元数据
372
+ download_meta_btn.click(
373
+ fn=None,
374
+ inputs=[hidden_meta_file],
375
+ outputs=None,
376
+ js="(f) => { const a = document.createElement('a'); a.href = f; a.download = f.split('/').pop(); document.body.appendChild(a); a.click(); document.body.removeChild(a); }"
377
+ )
378
+
379
+ # 设置文件下载(通过返回值触发)
380
+ generate_btn.change(
381
+ fn=lambda img_bytes, meta_bytes, img_name, meta_name: (
382
+ gr.File(value=io.BytesIO(img_bytes), label=img_name, visible=True),
383
+ gr.File(value=io.BytesIO(meta_bytes), label=meta_name, visible=True)
384
+ ),
385
+ inputs=[hidden_img_file, hidden_meta_file, hidden_img_file, hidden_meta_file],
386
+ outputs=[hidden_img_file, hidden_meta_file]
387
  )
388
 
389
+ # ======================
390
+ # 启动
391
+ # ======================
392
  if __name__ == "__main__":
393
  demo.launch()
394
+ ```