gokaygokay committed on
Commit
3052370
1 Parent(s): aba30b8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +349 -97
app.py CHANGED
@@ -1,119 +1,371 @@
1
  import spaces
2
- import gradio as gr
3
  import torch
4
  import random
5
- from diffusers import DiffusionPipeline
6
- import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- # Initialize models
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
- dtype = torch.bfloat16
 
 
 
 
 
11
 
12
- huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
 
 
 
 
13
 
14
- # Initialize the base model and move it to GPU
15
- base_model = "black-forest-labs/FLUX.1-dev"
16
- pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.bfloat16, token=huggingface_token).to("cuda")
 
17
 
18
- # Load LoRA weights
19
- pipe.load_lora_weights("gokaygokay/Flux-Detailer-LoRA")
20
- pipe.fuse_lora()
21
 
22
- MAX_SEED = 2**32-1
23
 
24
- @spaces.GPU(duration=75)
25
- def generate_image(prompt, steps=28, seed=None, cfg_scale=3.5, width=1024, height=1024, lora_scale=1.0):
26
- if seed is None:
27
- seed = random.randint(0, MAX_SEED)
28
- generator = torch.Generator(device="cuda").manual_seed(seed)
 
 
 
29
 
30
- image = pipe(
31
- prompt=prompt,
32
- num_inference_steps=int(steps),
33
- guidance_scale=cfg_scale,
34
- width=int(width),
35
- height=int(height),
36
- generator=generator,
37
- joint_attention_kwargs={"scale": lora_scale},
38
- ).images[0]
39
- return image
40
-
41
- def run_lora(prompt, cfg_scale=3.5, steps=28, randomize_seed=True, seed=None, width=1024, height=1024, lora_scale=1.0):
42
- # Handle the case when only prompt is provided (for Examples)
43
- if isinstance(prompt, str) and all(param is None for param in [cfg_scale, steps, randomize_seed, seed, width, height, lora_scale]):
44
- cfg_scale = 3.5
45
- steps = 28
46
- randomize_seed = True
47
- seed = None
48
- width = 1024
49
- height = 1024
50
- lora_scale = 1.0
51
-
52
- if randomize_seed or seed is None:
53
- seed = random.randint(0, MAX_SEED)
 
54
 
55
- image = generate_image(prompt, steps, seed, cfg_scale, width, height, lora_scale)
56
- return image, seed
57
-
58
- custom_css = """
59
- .input-group, .output-group {
60
- border: 1px solid #e0e0e0;
61
- border-radius: 10px;
62
- padding: 20px;
63
- margin-bottom: 20px;
64
- background-color: #f9f9f9;
65
- }
66
- .submit-btn {
67
- background-color: #2980b9 !important;
68
- color: white !important;
69
- }
70
- .submit-btn:hover {
71
- background-color: #3498db !important;
72
- }
73
- """
74
-
75
- title = """<h1 align="center">FLUX Creativity LoRA</h1>
76
- """
77
- examples = [
78
- ["anime, cartoon, Hyper-detailed, endearing anime girl, bathed in a vibrant, colorful psychedelic glow, wearing dazzling, holographic Liquid Metal outfit, in a cozy tatami studio", 0.5],
79
- ["extraterrestrial visage, close-up, highly intricate, ultra-detailed, full high definition", 0.5],
80
- ["a full body photo shot of a beautiful and breathtaking image of a ((Man) ) wearing a fully clothed casual witchy witch clothes with intricate details in the style of a reapers cloak, he is holding a long curved double edged ((scythe) ). This full body image is a one of a kind unique highly detailed with 8k sharp focus quality masterpiece, hyper detailed, extremely detailed", 0.5],
81
- ["schizophrenia attacks,go haywire, go crazy, hyper detailed, extremely detailed", 0.5],
82
- ]
83
-
84
- with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray"), css=custom_css) as app:
85
- gr.HTML(title)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  with gr.Row():
88
  with gr.Column(scale=1):
89
- prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Type your prompt here")
 
 
 
 
 
90
 
91
- with gr.Accordion("Advanced Settings", open=False):
92
- cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
93
- steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
94
- width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
95
- height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
96
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
97
- seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
98
- lora_scale = gr.Slider(label="LoRA Scale", minimum=0.0, maximum=1.0, step=0.01, value=1.0)
 
 
99
 
100
- generate_button = gr.Button("Generate", variant="primary", elem_classes="submit-btn")
101
-
102
- with gr.Column(scale=1):
103
- result = gr.Image(label="Generated Image")
 
104
 
105
- inputs = [prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale]
106
- outputs = [result, seed]
 
107
 
108
- generate_button.click(fn=run_lora, inputs=inputs, outputs=outputs)
109
- prompt.submit(fn=run_lora, inputs=inputs, outputs=outputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
- gr.Examples(
112
- examples=examples,
113
- inputs=[prompt, lora_scale],
114
- outputs=[result, seed],
115
- fn=run_lora,
116
- cache_examples=True
 
 
 
 
 
 
117
  )
118
 
119
- app.launch(debug=True)
 
1
  import spaces
2
+ import os
3
  import torch
4
  import random
5
+ from huggingface_hub import snapshot_download
6
+ from diffusers import StableDiffusionXLPipeline, AutoencoderKL
7
+ from diffusers import (
8
+ EulerAncestralDiscreteScheduler,
9
+ DPMSolverMultistepScheduler,
10
+ DPMSolverSDEScheduler,
11
+ HeunDiscreteScheduler,
12
+ DDIMScheduler,
13
+ LMSDiscreteScheduler,
14
+ PNDMScheduler,
15
+ UniPCMultistepScheduler,
16
+ )
17
+ from diffusers.models.attention_processor import AttnProcessor2_0
18
+ import gradio as gr
19
+ from PIL import Image
20
+ import numpy as np
21
+ from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
22
+ import requests
23
+ from RealESRGAN import RealESRGAN
24
+
25
+
26
import subprocess

# Install flash-attn at startup with FLASH_ATTENTION_SKIP_CUDA_BUILD set.
# The variable is layered on top of the inherited environment: passing a bare
# dict as ``env`` would replace the whole environment of the child shell
# (PATH, HOME, ...), which can make ``pip`` unresolvable.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)
28
+
29
def download_file(url, folder_path, filename):
    """Download ``url`` to ``folder_path/filename`` unless that file already exists.

    The payload is streamed into a sibling ``<filename>.part`` file and renamed
    onto the final path only after the whole body has been written, so an
    interrupted download can never be mistaken for a finished one by the
    "already exists" check on a later run.

    Args:
        url: HTTP(S) address of the file.
        folder_path: Destination directory; created if missing.
        filename: Name of the file inside ``folder_path``.

    Returns:
        None. Progress and non-200 responses are reported with ``print``;
        a non-200 response leaves the destination untouched.

    Raises:
        requests.RequestException: On connection errors or timeouts.
        OSError: If the destination cannot be created or written.
    """
    os.makedirs(folder_path, exist_ok=True)
    file_path = os.path.join(folder_path, filename)

    if os.path.isfile(file_path):
        print(f"File already exists: {file_path}")
        return

    partial_path = file_path + ".part"
    # The context manager releases the streamed connection on every exit path;
    # the timeout keeps a stalled server from hanging startup forever.
    with requests.get(url, stream=True, timeout=60) as response:
        if response.status_code != 200:
            print(f"Error downloading the file. Status code: {response.status_code}")
            return
        with open(partial_path, 'wb') as file:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                file.write(chunk)

    # Atomic publish: the final name only ever points at a complete file.
    os.replace(partial_path, file_path)
    print(f"File successfully downloaded and saved: {file_path}")
45
+
46
# Fetch the Real-ESRGAN upscaler weights (x2 and x4) into models/upscalers/.
for _scale in (2, 4):
    _weights_name = f"RealESRGAN_x{_scale}.pth"
    download_file(
        f"https://huggingface.co/ai-forever/Real-ESRGAN/resolve/main/{_weights_name}?download=true",
        "models/upscalers/",
        _weights_name,
    )

# Resolve a local snapshot directory for each of the three SDXL checkpoints.
ckpt_dir_pony, ckpt_dir_cyber, ckpt_dir_stallion = (
    snapshot_download(repo_id=_repo_id)
    for _repo_id in (
        "Niggendar/autismmixSDXL_autismmixPony",
        "John6666/t-ponynai3-v6-sdxl",
        "John6666/prefect-pony-xl-v2-cleaned-style-sdxl",
    )
)
54
+
55
# Build one fp16 VAE and one fp16 SDXL pipeline per checkpoint directory.
def _load_vae(ckpt_dir):
    """Load the checkpoint's ``vae`` sub-folder as a float16 AutoencoderKL."""
    return AutoencoderKL.from_pretrained(os.path.join(ckpt_dir, "vae"), torch_dtype=torch.float16)


def _load_pipeline(ckpt_dir, vae):
    """Load the fp16 safetensors SDXL pipeline stored in ``ckpt_dir`` with ``vae``."""
    return StableDiffusionXLPipeline.from_pretrained(
        ckpt_dir,
        vae=vae,
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
    )


vae_pony = _load_vae(ckpt_dir_pony)
vae_cyber = _load_vae(ckpt_dir_cyber)
vae_stallion = _load_vae(ckpt_dir_stallion)

pipe_pony = _load_pipeline(ckpt_dir_pony, vae_pony)
pipe_cyber = _load_pipeline(ckpt_dir_cyber, vae_cyber)
pipe_stallion = _load_pipeline(ckpt_dir_stallion, vae_stallion)

# Move every pipeline to the GPU, then switch their UNets to AttnProcessor2_0.
pipe_pony, pipe_cyber, pipe_stallion = (
    _pipe.to("cuda") for _pipe in (pipe_pony, pipe_cyber, pipe_stallion)
)
for _pipe in (pipe_pony, pipe_cyber, pipe_stallion):
    _pipe.unet.set_attn_processor(AttnProcessor2_0())
89
+
90
# Sampler name -> scheduler instance, all derived from the pony pipeline's
# scheduler config. Key order is the order shown in the UI dropdown.
_scheduler_config = pipe_pony.scheduler.config
samplers = {
    "Euler a": EulerAncestralDiscreteScheduler.from_config(_scheduler_config),
    "DPM++ SDE Karras": DPMSolverSDEScheduler.from_config(
        _scheduler_config, use_karras_sigmas=True
    ),
    "Heun": HeunDiscreteScheduler.from_config(_scheduler_config),
    # New samplers
    "DPM++ 2M SDE Karras": DPMSolverMultistepScheduler.from_config(
        _scheduler_config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
    ),
    "DPM++ 2M": DPMSolverMultistepScheduler.from_config(_scheduler_config),
    "DDIM": DDIMScheduler.from_config(_scheduler_config),
    "LMS": LMSDiscreteScheduler.from_config(_scheduler_config),
    "PNDM": PNDMScheduler.from_config(_scheduler_config),
    "UniPC": UniPCMultistepScheduler.from_config(_scheduler_config),
}

# Default fragments wrapped around the user's positive / negative prompts.
DEFAULT_POSITIVE_PREFIX = "Score_9 score_8_up score_7_up BREAK"
DEFAULT_POSITIVE_SUFFIX = (
    "(masterpiece) very_aesthetic detailed_face cinematic footage"
)
DEFAULT_NEGATIVE_PREFIX = "score_1, score_2, score_3, text"
DEFAULT_NEGATIVE_SUFFIX = (
    "nsfw, (low quality, worst quality:1.2), very displeasing, "
    "3d, watermark, signature, ugly, poorly drawn"
)
108
+
109
# Florence-2 captioning model and its processor, placed on the GPU when one
# is available and on the CPU otherwise.
_FLORENCE_REPO = 'microsoft/Florence-2-base'
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
florence_model = AutoModelForCausalLM.from_pretrained(_FLORENCE_REPO, trust_remote_code=True)
florence_model = florence_model.to(device).eval()
florence_processor = AutoProcessor.from_pretrained(_FLORENCE_REPO, trust_remote_code=True)

# Prompt enhancers: two summarization pipelines (medium and long variants).
enhancer_medium = pipeline(
    "summarization",
    model="gokaygokay/Lamini-Prompt-Enchance",
    device=device,
)
enhancer_long = pipeline(
    "summarization",
    model="gokaygokay/Lamini-Prompt-Enchance-Long",
    device=device,
)
117
 
118
class LazyRealESRGAN:
    """Real-ESRGAN wrapper that builds the network only on first use."""

    def __init__(self, device, scale):
        """Remember the target device and upscale factor; load nothing yet."""
        self.device, self.scale = device, scale
        self.model = None  # populated by load_model()

    def load_model(self):
        """Create the model and load its local weights, unless already done."""
        if self.model is not None:
            return
        self.model = RealESRGAN(self.device, scale=self.scale)
        weights_path = f'models/upscalers/RealESRGAN_x{self.scale}.pth'
        self.model.load_weights(weights_path, download=False)

    def predict(self, img):
        """Return the wrapped model's prediction for ``img``, loading on demand."""
        self.load_model()
        upscaled = self.model.predict(img)
        return upscaled
132
 
133
# Rebind ``device`` as a torch.device and create the two lazily loaded
# upscalers (x2 and x4) on it.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

lazy_realesrgan_x2, lazy_realesrgan_x4 = (
    LazyRealESRGAN(device, scale=_factor) for _factor in (2, 4)
)
137
+
138
+ # Florence caption function
139
def florence_caption(image):
    """Return Florence-2's ``<DETAILED_CAPTION>`` output for ``image``.

    ``image`` may be a PIL image or anything ``Image.fromarray`` accepts.
    """
    task = "<DETAILED_CAPTION>"
    pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)

    model_inputs = florence_processor(text=task, images=pil_image, return_tensors="pt").to(device)
    output_ids = florence_model.generate(
        input_ids=model_inputs["input_ids"],
        pixel_values=model_inputs["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )
    # Special tokens are kept because the post-processing step parses them.
    decoded = florence_processor.batch_decode(output_ids, skip_special_tokens=False)[0]
    parsed = florence_processor.post_process_generation(
        decoded,
        task=task,
        image_size=(pil_image.width, pil_image.height),
    )
    return parsed[task]
160
+
161
+ # Prompt Enhancer function
162
def enhance_prompt(input_prompt, model_choice):
    """Expand ``input_prompt`` with the selected enhancer pipeline.

    ``model_choice == "Medium"`` selects the medium enhancer; any other value
    selects the long one. Returns the ``summary_text`` of the first result.
    """
    enhancer = enhancer_medium if model_choice == "Medium" else enhancer_long
    summaries = enhancer("Enhance the description: " + input_prompt)
    return summaries[0]['summary_text']
171
+
172
def upscale_image(image, scale):
    """Upscale ``image`` by 2 or 4 with Real-ESRGAN; other scales return it as-is.

    Accepts a PIL image or a numpy array (converted to PIL first).

    Raises:
        ValueError: If ``image`` is neither a PIL image nor a numpy array.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif not isinstance(image, Image.Image):
        raise ValueError("Input must be a PIL Image or a numpy array")

    upscaler = None
    if scale == 2:
        upscaler = lazy_realesrgan_x2
    elif scale == 4:
        upscaler = lazy_realesrgan_x4

    return image if upscaler is None else upscaler.predict(image)
186
+
187
def _join_prompt_parts(*parts):
    """Join the non-empty prompt fragments with ", " (no stray separators)."""
    return ", ".join(part for part in parts if part)


@spaces.GPU(duration=120)
def generate_image(model_choice, additional_positive_prompt, additional_negative_prompt, height, width, num_inference_steps,
                   guidance_scale, num_images_per_prompt, use_random_seed, seed, sampler, clip_skip,
                   use_florence2, use_medium_enhancer, use_long_enhancer,
                   use_positive_prefix, use_positive_suffix, use_negative_prefix, use_negative_suffix,
                   use_upscaler, upscale_factor,
                   input_image=None, progress=gr.Progress(track_tqdm=True)):
    """Generate images with the selected SDXL pipeline and optional extras.

    Args:
        model_choice: "AutismMix SDXL", "T-ponynai3", or any other value
            (the UI offers "Prefect Pony XL") for the third pipeline.
        additional_positive_prompt: User text for the positive prompt.
        additional_negative_prompt: User text for the negative prompt.
        height, width, num_inference_steps, guidance_scale,
        num_images_per_prompt: Forwarded to the pipeline call.
        use_random_seed: Draw a fresh seed instead of using ``seed``.
        seed: Seed used when ``use_random_seed`` is false.
        sampler: Key into the module-level ``samplers`` table.
        clip_skip: UI clip-skip value (slider range 1-4).
        use_florence2: Prepend a Florence-2 caption of ``input_image``.
        use_medium_enhancer, use_long_enhancer: Append enhancer output.
        use_*_prefix / use_*_suffix: Toggle the DEFAULT_* prompt fragments.
        use_upscaler, upscale_factor: Optionally upscale every result.
        input_image: Image captioned when ``use_florence2`` is set.
        progress: Gradio progress tracker (tqdm tracking enabled).

    Returns:
        ``(images, seed, full_positive_prompt, full_negative_prompt)``;
        ``images`` is ``None`` when generation raised an exception.
    """
    # Select the appropriate pipe based on the model choice
    if model_choice == "AutismMix SDXL":
        pipe = pipe_pony
    elif model_choice == "T-ponynai3":
        pipe = pipe_cyber
    else:  # "Prefect Pony XL"
        pipe = pipe_stallion

    if use_random_seed:
        seed = random.randint(0, 2**32 - 1)
    else:
        seed = int(seed)  # Ensure seed is an integer

    # Set the scheduler based on the selected sampler
    pipe.scheduler = samplers[sampler]

    # Clip skip is passed per call instead of decrementing
    # pipe.text_encoder.config.num_hidden_layers: that mutation hit shared
    # global state and compounded with every request.
    # NOTE(review): the UI value follows the "1 = last layer, 2 = penultimate"
    # convention, while diffusers' SDXL pipeline is understood to use the
    # penultimate layer when clip_skip is None and to count additional
    # skipped layers from there -- confirm against the installed diffusers.
    pipeline_clip_skip = max(int(clip_skip) - 2, 0) or None

    # Add Florence-2 caption if enabled and image is provided
    if use_florence2 and input_image is not None:
        florence2_caption = florence_caption(input_image)
        florence2_caption = florence2_caption.lower().replace('.', ',')
        additional_positive_prompt = (
            f"{florence2_caption}, {additional_positive_prompt}"
            if additional_positive_prompt
            else florence2_caption
        )

    # Enhance only the additional positive prompt if enhancers are enabled
    enhanced_prompt = ""
    if additional_positive_prompt:
        enhanced_prompt = additional_positive_prompt
        if use_medium_enhancer:
            medium_enhanced = enhance_prompt(enhanced_prompt, "Medium")
            medium_enhanced = medium_enhanced.lower().replace('.', ',')
            enhanced_prompt = f"{enhanced_prompt}, {medium_enhanced}"
        if use_long_enhancer:
            long_enhanced = enhance_prompt(enhanced_prompt, "Long")
            long_enhanced = long_enhanced.lower().replace('.', ',')
            enhanced_prompt = f"{enhanced_prompt}, {long_enhanced}"

    # Only non-empty fragments are joined, so an empty user prompt no longer
    # yields "prefix, , suffix" or a leading ", ".
    full_positive_prompt = _join_prompt_parts(
        DEFAULT_POSITIVE_PREFIX if use_positive_prefix else "",
        enhanced_prompt,
        DEFAULT_POSITIVE_SUFFIX if use_positive_suffix else "",
    )
    full_negative_prompt = _join_prompt_parts(
        DEFAULT_NEGATIVE_PREFIX if use_negative_prefix else "",
        additional_negative_prompt,
        DEFAULT_NEGATIVE_SUFFIX if use_negative_suffix else "",
    )

    try:
        images = pipe(
            prompt=full_positive_prompt,
            negative_prompt=full_negative_prompt,
            # Slider values may arrive as floats; the pipeline needs ints.
            height=int(height),
            width=int(width),
            num_inference_steps=int(num_inference_steps),
            guidance_scale=guidance_scale,
            num_images_per_prompt=int(num_images_per_prompt),
            clip_skip=pipeline_clip_skip,
            generator=torch.Generator(pipe.device).manual_seed(seed)
        ).images

        if use_upscaler:
            print("Upscaling images")
            upscaled_images = []
            for i, img in enumerate(images):
                print(f"Upscaling image {i+1}")
                if not isinstance(img, Image.Image):
                    print(f"Converting image {i+1} to PIL Image")
                    img = Image.fromarray(np.uint8(img))
                upscaled_images.append(upscale_image(img, upscale_factor))
            images = upscaled_images

        print("Returning results")
        return images, seed, full_positive_prompt, full_negative_prompt
    except Exception as e:
        # Top-level UI boundary: report the failure and still return the seed
        # and prompts so the interface can display what was attempted.
        print(f"Error during image generation: {str(e)}")
        import traceback
        traceback.print_exc()
        return None, seed, full_positive_prompt, full_negative_prompt
279
+
280
+ # Gradio interface
281
# Gradio interface: settings on the left, results on the right.
with gr.Blocks(theme='bethecloud/storj_theme') as demo:
    gr.HTML("""
    <h1 align="center">Pony Realism / Cyber Realism / Stallion Dreams</h1>
    <p align="center">
    <a href="https://huggingface.co/Niggendar/autismmixSDXL_autismmixPony/" target="_blank">[AutismMix SDXL]</a>
    <a href="https://huggingface.co/John6666/t-ponynai3-v6-sdxl" target="_blank">[T-ponynai3]</a>
    <a href="https://huggingface.co/John6666/prefect-pony-xl-v2-cleaned-style-sdxl" target="_blank">[Prefect Pony XL]</a><br>
    <a href="https://civitai.com/models/288584/autismmix-sdxl" target="_blank">[AutismMix SDXL civitai]</a>
    <a href="https://civitai.com/models/317902/t-ponynai3" target="_blank">[T-ponynai3 civitai]</a>
    <a href="https://civitai.com/models/439889/prefect-pony-xl" target="_blank">[Prefect Pony XL civitai]</a>
    <a href="https://huggingface.co/microsoft/Florence-2-base" target="_blank">[Florence-2 Model]</a>
    <a href="https://huggingface.co/gokaygokay/Lamini-Prompt-Enchance-Long" target="_blank">[Prompt Enhancer Long]</a>
    <a href="https://huggingface.co/gokaygokay/Lamini-Prompt-Enchance" target="_blank">[Prompt Enhancer Medium]</a>
    </p>
    """)

    with gr.Row():
        # Left column: model, prompts and every generation setting.
        with gr.Column(scale=1):
            model_choice = gr.Dropdown(
                choices=["AutismMix SDXL", "T-ponynai3", "Prefect Pony XL"],
                label="Model Choice",
                value="AutismMix SDXL",
            )
            positive_prompt = gr.Textbox(label="Positive Prompt", placeholder="Add your positive prompt here")
            negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Add your negative prompt here")

            with gr.Accordion(label="Advanced settings", open=False):
                height = gr.Slider(minimum=512, maximum=2048, value=1024, step=64, label="Height")
                width = gr.Slider(minimum=512, maximum=2048, value=1024, step=64, label="Width")
                num_inference_steps = gr.Slider(
                    minimum=20, maximum=100, value=30, step=1, label="Number of Inference Steps"
                )
                guidance_scale = gr.Slider(minimum=1, maximum=20, value=6, step=0.1, label="Guidance Scale")
                num_images_per_prompt = gr.Slider(
                    minimum=1, maximum=4, value=1, step=1, label="Number of images per prompt"
                )
                use_random_seed = gr.Checkbox(label="Use Random Seed", value=True)
                seed = gr.Number(label="Seed", value=0, precision=0)
                sampler = gr.Dropdown(label="Sampler", choices=list(samplers.keys()), value="Euler a")
                clip_skip = gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Clip skip")

            with gr.Accordion(label="Captioner and Enhancers", open=False):
                input_image = gr.Image(label="Input Image for Florence-2 Captioner")
                use_florence2 = gr.Checkbox(label="Use Florence-2 Captioner", value=False)
                use_medium_enhancer = gr.Checkbox(label="Use Medium Prompt Enhancer", value=False)
                use_long_enhancer = gr.Checkbox(label="Use Long Prompt Enhancer", value=False)

            with gr.Accordion(label="Upscaler Settings", open=False):
                use_upscaler = gr.Checkbox(label="Use Upscaler", value=False)
                upscale_factor = gr.Radio(label="Upscale Factor", choices=[2, 4], value=2)

            generate_btn = gr.Button(value="Generate Image")

            with gr.Accordion(label="Prefix and Suffix Settings", open=True):
                use_positive_prefix = gr.Checkbox(
                    label="Use Positive Prefix", value=True, info=f"Prefix: {DEFAULT_POSITIVE_PREFIX}"
                )
                use_positive_suffix = gr.Checkbox(
                    label="Use Positive Suffix", value=True, info=f"Suffix: {DEFAULT_POSITIVE_SUFFIX}"
                )
                use_negative_prefix = gr.Checkbox(
                    label="Use Negative Prefix", value=True, info=f"Prefix: {DEFAULT_NEGATIVE_PREFIX}"
                )
                use_negative_suffix = gr.Checkbox(
                    label="Use Negative Suffix", value=True, info=f"Suffix: {DEFAULT_NEGATIVE_SUFFIX}"
                )

        # Right column: generated images plus the seed and prompts actually used.
        with gr.Column(scale=1):
            output_gallery = gr.Gallery(label="Result", elem_id="gallery", show_label=False)
            seed_used = gr.Number(label="Seed Used")
            full_positive_prompt_used = gr.Textbox(label="Full Positive Prompt Used")
            full_negative_prompt_used = gr.Textbox(label="Full Negative Prompt Used")

    # Component order mirrors generate_image's positional parameters.
    generation_inputs = [
        model_choice,
        positive_prompt, negative_prompt, height, width, num_inference_steps,
        guidance_scale, num_images_per_prompt, use_random_seed, seed, sampler,
        clip_skip, use_florence2, use_medium_enhancer, use_long_enhancer,
        use_positive_prefix, use_positive_suffix, use_negative_prefix, use_negative_suffix,
        use_upscaler, upscale_factor,
        input_image,
    ]
    generation_outputs = [
        output_gallery, seed_used, full_positive_prompt_used, full_negative_prompt_used
    ]
    generate_btn.click(fn=generate_image, inputs=generation_inputs, outputs=generation_outputs)

demo.launch(debug=True)