chore: implement presets in main app

#3
Files changed (1)
  1. app_local.py +116 -143
app_local.py CHANGED
@@ -6,6 +6,8 @@ import spaces
 from PIL import Image
 from diffusers import QwenImageEditPipeline, FlowMatchEulerDiscreteScheduler
 from diffusers.utils import is_xformers_available
+from presets import PRESETS, get_preset_choices, get_preset_info
+
 import os
 import sys
 import re
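
The new import pulls `PRESETS`, `get_preset_choices`, and `get_preset_info` from a `presets` module that is not part of this diff. For review context, a minimal sketch of the structure `infer` relies on — every name and value below is illustrative; only the keys `prompts`, `count`, `description` and the three exports are implied by this PR's code:

```python
# presets.py — hypothetical sketch, NOT the actual module from this PR.
# The diff only implies: PRESETS[name]["prompts"] (list of prompt suffixes),
# PRESETS[name]["count"] (batch size), PRESETS[name]["description"] (str),
# plus get_preset_choices() for the dropdown and get_preset_info().

PRESETS = {
    "Four Seasons": {  # example preset name
        "description": "Render the edit in four seasonal variations",
        "prompts": [
            "in spring with fresh greenery",
            "in summer with bright sunlight",
            "in autumn with golden leaves",
            "in winter with falling snow",
        ],
        "count": 4,  # expected to equal len(prompts); see note below
    },
}

def get_preset_choices():
    """Preset names for the gr.Dropdown; Gradio passes None when nothing is picked."""
    return list(PRESETS.keys())

def get_preset_info(name):
    """Description text for a preset name."""
    preset = PRESETS.get(name)
    return preset["description"] if preset else ""
```

Note that `infer` assigns `preset["count"]` to `num_images_per_prompt` but then loops over `batch_prompts`, so `count` only matters if it equals `len(prompts)`; also, `get_preset_info` is imported but not referenced anywhere in this file so far.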
@@ -85,7 +87,6 @@ Please provide the rewritten instruction in a clean `json` format as:
 }
 '''
 
-
 def extract_json_response(model_output: str) -> str:
     """Extract rewritten instruction from potentially messy JSON output"""
     # Remove code block markers first
@@ -94,19 +95,15 @@ def extract_json_response(model_output: str) -> str:
     # Find the JSON portion in the output
     start_idx = model_output.find('{')
     end_idx = model_output.rfind('}')
-
     # Fix the condition - check if brackets were found
     if start_idx == -1 or end_idx == -1 or start_idx >= end_idx:
         print(f"No valid JSON structure found in output. Start: {start_idx}, End: {end_idx}")
         return None
-
     # Expand to the full object including outer braces
     end_idx += 1  # Include the closing brace
     json_str = model_output[start_idx:end_idx]
-
     # Handle potential markdown or other formatting
     json_str = json_str.strip()
-
     # Try to parse JSON directly first
     try:
         data = json.loads(json_str)
@@ -119,7 +116,6 @@ def extract_json_response(model_output: str) -> str:
         json_str = re.sub(r',(\s*[}\]])', r'\1', json_str)
         # Try parsing again
         data = json.loads(json_str)
-
         # Extract rewritten prompt from possible key variations
         possible_keys = [
             "Rewritten", "rewritten", "Rewrited", "rewrited", "Rewrittent",
@@ -128,45 +124,36 @@ def extract_json_response(model_output: str) -> str:
         for key in possible_keys:
             if key in data:
                 return data[key].strip()
-
         # Try nested path
         if "Response" in data and "Rewritten" in data["Response"]:
             return data["Response"]["Rewritten"].strip()
-
         # Handle nested JSON objects (additional protection)
         if isinstance(data, dict):
             for value in data.values():
                 if isinstance(value, dict) and "Rewritten" in value:
                     return value["Rewritten"].strip()
-
         # Try to find any string value that looks like an instruction
         str_values = [v for v in data.values() if isinstance(v, str) and 10 < len(v) < 500]
         if str_values:
             return str_values[0].strip()
-
     except Exception as e:
         print(f"JSON parse error: {str(e)}")
         print(f"Model output was: {model_output}")
         return None
 
-
 def polish_prompt(original_prompt: str) -> str:
     """Enhanced prompt rewriting using original system prompt with JSON handling"""
-
     # Format as Qwen chat
     messages = [
         {"role": "system", "content": SYSTEM_PROMPT_EDIT},
         {"role": "user", "content": original_prompt}
     ]
-
     text = rewriter_tokenizer.apply_chat_template(
         messages,
         tokenize=False,
         add_generation_prompt=True
     )
-
     model_inputs = rewriter_tokenizer(text, return_tensors="pt").to(device)
-
     with torch.no_grad():
         generated_ids = rewriter_model.generate(
             **model_inputs,
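
On the comma-repair path in `extract_json_response` above: it slices out the outermost braces, strips trailing commas, then re-parses. A standalone repro of that salvage sequence (the sample string is illustrative):

```python
import json
import re

# Typical messy rewriter output: fenced markdown plus a trailing comma.
raw = '```json\n{"Rewritten": "Replace the sky with a sunset",}\n```'

json_str = raw[raw.find('{'):raw.rfind('}') + 1]    # keep only the {...} portion
json_str = re.sub(r',(\s*[}\]])', r'\1', json_str)  # drop trailing commas before } or ]
print(json.loads(json_str)["Rewritten"])            # -> Replace the sky with a sunset
```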
@@ -178,18 +165,14 @@ def polish_prompt(original_prompt: str) -> str:
             no_repeat_ngram_size=3,
             pad_token_id=rewriter_tokenizer.eos_token_id
         )
-
     # Extract and clean response
     enhanced = rewriter_tokenizer.decode(
         generated_ids[0][model_inputs.input_ids.shape[1]:],
         skip_special_tokens=True
     ).strip()
-
     print(f"Model raw output: {enhanced}")  # Debug logging
-
     # Try to extract JSON content
     rewritten_prompt = extract_json_response(enhanced)
-
     if rewritten_prompt:
         # Clean up remaining artifacts
         rewritten_prompt = re.sub(r'(Replace|Change|Add) "(.*?)"', r'\1 \2', rewritten_prompt)
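The artifact cleanup above un-quotes edit targets that the rewriter tends to wrap in double quotes. In isolation:

```python
import re

enhanced = 'Replace "the cloudy sky" with a clear sunset'
cleaned = re.sub(r'(Replace|Change|Add) "(.*?)"', r'\1 \2', enhanced)
print(cleaned)  # Replace the cloudy sky with a clear sunset
```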
@@ -205,12 +188,10 @@ def polish_prompt(original_prompt: str) -> str:
             rewritten_prompt = enhanced
     else:
         rewritten_prompt = enhanced
-
     # Basic cleanup
     rewritten_prompt = re.sub(r'\s\s+', ' ', rewritten_prompt).strip()
     if ': ' in rewritten_prompt:
         rewritten_prompt = rewritten_prompt.split(': ', 1)[-1].strip()
-
     return rewritten_prompt[:200] if rewritten_prompt else original_prompt
 
 # Scheduler configuration for Lightning
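
And `polish_prompt`'s final cleanup, shown in the hunk above: collapse repeated whitespace, drop any leading `label: ` prefix, and hard-cap the result at 200 characters. In isolation:

```python
import re

rewritten = "Rewritten:  Replace   the background with a beach sunset"
rewritten = re.sub(r'\s\s+', ' ', rewritten).strip()   # collapse whitespace runs
if ': ' in rewritten:
    rewritten = rewritten.split(': ', 1)[-1].strip()   # strip a "Label: " prefix
print(rewritten[:200])  # Replace the background with a beach sunset
```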
@@ -231,6 +212,7 @@ scheduler_config = {
     "use_karras_sigmas": False,
 }
 
+
 # Initialize scheduler with Lightning config
 scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
 
@@ -254,15 +236,7 @@ if is_xformers_available():
 else:
     print("xformers not available")
 
-# def unload_rewriter():
-#     """Clear enhancement model from memory"""
-#     global rewriter_tokenizer, rewriter_model
-#     if rewriter_model:
-#         del rewriter_tokenizer, rewriter_model
-#         rewriter_tokenizer = None
-#         rewriter_model = None
-#         torch.cuda.empty_cache()
-#         gc.collect()
+
 @spaces.GPU()
 def infer(
     image,
@@ -273,33 +247,28 @@ def infer(
     num_inference_steps=8,
     rewrite_prompt=True,
     num_images_per_prompt=1,
+    preset_type=None,  # New parameter for presets
     progress=gr.Progress(track_tqdm=True),
 ):
     """Image editing endpoint with optimized prompt handling"""
-
     # Resize image to max 1024px on longest side
     def resize_image(pil_image, max_size=1024):
         """Resize image to maximum dimension of 1024px while maintaining aspect ratio"""
         try:
             if pil_image is None:
                 return pil_image
-
             width, height = pil_image.size
             max_dimension = max(width, height)
-
             if max_dimension <= max_size:
                 return pil_image  # No resize needed
-
             # Calculate new dimensions maintaining aspect ratio
             scale = max_size / max_dimension
             new_width = int(width * scale)
             new_height = int(height * scale)
-
             # Resize image
             resized_image = pil_image.resize((new_width, new_height), Image.LANCZOS)
             print(f"📝 Image resized from {width}x{height} to {new_width}x{new_height}")
             return resized_image
-
         except Exception as e:
             print(f"⚠️ Image resize failed: {e}")
             return pil_image  # Return original if resize fails
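
The resize math is worth sanity-checking once: a single scale factor keeps the aspect ratio while capping the longest side. For a hypothetical 4000×3000 input:

```python
width, height, max_size = 4000, 3000, 1024
scale = max_size / max(width, height)                 # 1024 / 4000 = 0.256
new_size = (int(width * scale), int(height * scale))  # (1024, 768)
assert new_size == (1024, 768)                        # longest side capped, ratio kept
```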
@@ -310,7 +279,6 @@ def infer(
         try:
             if pil_image is None:
                 return pil_image
-
             img_array = np.array(pil_image).astype(np.float32) / 255.0
             noise = np.random.normal(0, noise_level, img_array.shape)
             noisy_array = img_array + noise
@@ -322,96 +290,105 @@ def infer(
         except Exception as e:
             print(f"Warning: Could not add noise to image: {e}")
             return pil_image  # Return original if noise addition fails
-
+
     # Resize input image first
     image = resize_image(image, max_size=1024)
-
     original_prompt = prompt
     prompt_info = ""
 
-    # Handle prompt rewriting
-    if rewrite_prompt:
-        try:
-            enhanced_instruction = polish_prompt(original_prompt)
-            if enhanced_instruction and enhanced_instruction != original_prompt:
-                prompt_info = (
-                    f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #4CAF50; background: #f5f9fe'>"
-                    f"<h4 style='margin-top: 0;'>🚀 Prompt Enhancement</h4>"
-                    f"<p><strong>Original:</strong> {original_prompt}</p>"
-                    f"<p><strong style='color:#2E7D32;'>Enhanced:</strong> {enhanced_instruction}</p>"
-                    f"</div>"
-                )
-                prompt = enhanced_instruction
-            else:
-                prompt_info = (
-                    f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #FF9800; background: #fff8f0'>"
-                    f"<h4 style='margin-top: 0;'>📝 Prompt Enhancement</h4>"
-                    f"<p>No enhancement applied or enhancement failed</p>"
-                    f"</div>"
-                )
-        except Exception as e:
-            print(f"Prompt enhancement error: {str(e)}")  # Debug logging
-            gr.Warning(f"Prompt enhancement failed: {str(e)}")
-            prompt_info = (
-                f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #FF5252; background: #fef5f5'>"
-                f"<h4 style='margin-top: 0;'>⚠️ Enhancement Not Applied</h4>"
-                f"<p>Using original prompt. Error: {str(e)[:100]}</p>"
-                f"</div>"
-            )
-    else:
-        prompt_info = (
-            f"<div style='margin:10px; padding:10px; border-radius:8px; background: #f8f9fa'>"
-            f"<h4 style='margin-top: 0;'>📝 Original Prompt</h4>"
-            f"<p>{original_prompt}</p>"
-            f"</div>"
-        )
+    # Handle preset batch generation
+    if preset_type and preset_type in PRESETS:
+        preset = PRESETS[preset_type]
+        batch_prompts = [f"{original_prompt}, {preset_prompt}" for preset_prompt in preset["prompts"]]
+        num_images_per_prompt = preset["count"]
+        prompt_info = (
+            f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #2196F3; background: #f0f8ff'>"
+            f"<h4 style='margin-top: 0;'>🎨 Preset: {preset_type}</h4>"
+            f"<p>{preset['description']}</p>"
+            f"<p><strong>Base Prompt:</strong> {original_prompt}</p>"
+            f"</div>"
+        )
+        print(f"Using preset: {preset_type} with {len(batch_prompts)} variations")
+    else:
+        batch_prompts = [prompt]  # Single prompt in list
+
+    # Handle regular prompt rewriting
+    if rewrite_prompt:
+        try:
+            enhanced_instruction = polish_prompt(original_prompt)
+            if enhanced_instruction and enhanced_instruction != original_prompt:
+                prompt_info = (
+                    f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #4CAF50; background: #f5f9fe'>"
+                    f"<h4 style='margin-top: 0;'>🚀 Prompt Enhancement</h4>"
+                    f"<p><strong>Original:</strong> {original_prompt}</p>"
+                    f"<p><strong style='color:#2E7D32;'>Enhanced:</strong> {enhanced_instruction}</p>"
+                    f"</div>"
+                )
+                batch_prompts = [enhanced_instruction]
+            else:
+                prompt_info = (
+                    f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #FF9800; background: #fff8f0'>"
+                    f"<h4 style='margin-top: 0;'>📝 Prompt Enhancement</h4>"
+                    f"<p>No enhancement applied or enhancement failed</p>"
+                    f"</div>"
+                )
+        except Exception as e:
+            print(f"Prompt enhancement error: {str(e)}")  # Debug logging
+            gr.Warning(f"Prompt enhancement failed: {str(e)}")
+            prompt_info = (
+                f"<div style='margin:10px; padding:15px; border-radius:8px; border-left:4px solid #FF5252; background: #fef5f5'>"
+                f"<h4 style='margin-top: 0;'>⚠️ Enhancement Not Applied</h4>"
+                f"<p>Using original prompt. Error: {str(e)[:100]}</p>"
+                f"</div>"
+            )
+    else:
+        prompt_info = (
+            f"<div style='margin:10px; padding:10px; border-radius:8px; background: #f8f9fa'>"
+            f"<h4 style='margin-top: 0;'>📝 Original Prompt</h4>"
+            f"<p>{original_prompt}</p>"
+            f"</div>"
+        )
 
     # Set base seed for reproducibility
     base_seed = seed if not randomize_seed else random.randint(0, MAX_SEED)
 
     try:
-        # Generate images with variation for batch mode
-        if num_images_per_prompt > 1:
-            edited_images = []
-            for i in range(num_images_per_prompt):
-                # Create unique seed for each image
-                generator = torch.Generator(device=device).manual_seed(base_seed + i*1000)
-
-                # Add slight noise to the image for variation
-                noisy_image = add_noise_to_image(image, noise_level=0.05 + i*0.003)
-
-                # Slightly vary guidance scale
-                varied_guidance = true_guidance_scale + random.uniform(-0.5, 0.5)
-                varied_guidance = max(1.0, min(10.0, varied_guidance))
-
-                # Generate single image with variations
-                result = pipe(
-                    image=noisy_image,
-                    prompt=prompt,
-                    negative_prompt=" ",
-                    num_inference_steps=num_inference_steps,
-                    generator=generator,
-                    true_cfg_scale=varied_guidance,
-                    num_images_per_prompt=1
-                ).images
-                edited_images.extend(result)
-        else:
-            # Single image generation (unchanged)
-            generator = torch.Generator(device=device).manual_seed(base_seed)
-            edited_images = pipe(
-                image=image,
-                prompt=prompt,
-                negative_prompt=" ",
-                num_inference_steps=num_inference_steps,
-                generator=generator,
-                true_cfg_scale=true_guidance_scale,
-                num_images_per_prompt=num_images_per_prompt
-            ).images
+        edited_images = []
+
+        # Generate images for each prompt in the batch
+        for i, current_prompt in enumerate(batch_prompts):
+            # Create unique seed for each image
+            generator = torch.Generator(device=device).manual_seed(base_seed + i*1000)
+
+            # Add slight noise to the image for variation (except for first image to maintain base)
+            if i == 0 and len(batch_prompts) == 1:
+                input_image = image
+            else:
+                input_image = add_noise_to_image(image, noise_level=0.01 + i*0.003)
+
+            # Slightly vary guidance scale for each image
+            varied_guidance = true_guidance_scale + random.uniform(-0.2, 0.2)
+            varied_guidance = max(1.0, min(10.0, varied_guidance))
+
+            # Generate single image
+            result = pipe(
+                image=input_image,
+                prompt=current_prompt,
+                negative_prompt=" ",
+                num_inference_steps=num_inference_steps,
+                generator=generator,
+                true_cfg_scale=varied_guidance,
+                num_images_per_prompt=1
+            ).images
+            edited_images.extend(result)
+
+            print(f"Generated image {i+1}/{len(batch_prompts)} with prompt: {current_prompt[:50]}...")
 
         # Clear cache after generation
         if device == "cuda":
            torch.cuda.empty_cache()
            gc.collect()
+
         return edited_images, base_seed, prompt_info
     except Exception as e:
         # Clear cache on error
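
This hunk is the core behavioral change of the PR: instead of branching on `num_images_per_prompt`, `infer` always iterates over `batch_prompts`, so a preset fans one instruction out into several edits, each with its own seed offset, a lightly noised copy of the source image, and jittered guidance. A dry-run sketch of just that expansion (hypothetical preset contents; `pipe` is not invoked here):

```python
import random

base_prompt = "Make the jacket red"  # user's instruction
base_seed = 42
true_guidance_scale = 4.0

# Shaped like a PRESETS entry; contents are illustrative.
preset = {"prompts": ["in spring", "in summer", "in autumn", "in winter"]}
batch_prompts = [f"{base_prompt}, {p}" for p in preset["prompts"]]

for i, current_prompt in enumerate(batch_prompts):
    seed = base_seed + i * 1000   # unique generator seed per image
    noise = 0.01 + i * 0.003      # per-image input noise (the real loop skips
                                  # this when there is a single non-preset prompt)
    guidance = max(1.0, min(10.0, true_guidance_scale + random.uniform(-0.2, 0.2)))
    print(f"[{i}] seed={seed} noise={noise:.3f} cfg={guidance:.2f} -> {current_prompt!r}")
```

One interaction to flag: when a preset is selected and `rewrite_prompt` is left on, a successful enhancement resets `batch_prompts = [enhanced_instruction]`, silently discarding the preset batch and its `prompt_info` panel; guarding the rewrite branch with `not preset_type` would avoid that.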
@@ -425,13 +402,14 @@ def infer(
             f"<p>{str(e)[:200]}</p>"
             f"</div>"
         )
-
+
+
 with gr.Blocks(title="Qwen Image Edit - Fast Lightning Mode w/ Batch") as demo:
     gr.Markdown("""
     <div style="text-align: center; background: linear-gradient(to right, #3a7bd5, #00d2ff); color: white; padding: 20px; border-radius: 8px;">
         <h1 style="margin-bottom: 5px;">⚡️ Qwen-Image-Edit Lightning</h1>
         <p>✨ 8-step inferencing with lightx2v's LoRA.</p>
-        <p>📝 Local Prompt Enhancement, Batched Multi-image Generation</p>
+        <p>📝 Local Prompt Enhancement, Batched Multi-image Generation, 🎨 Preset Batches</p>
     </div>
     """)
 
@@ -439,65 +417,72 @@ with gr.Blocks(title="Qwen Image Edit - Fast Lightning Mode w/ Batch") as demo:
         # Input Column
         with gr.Column(scale=1):
             input_image = gr.Image(
-                label="Source Image",
-                type="pil",
+                label="Source Image",
+                type="pil",
                 height=300
             )
             prompt = gr.Textbox(
-                label="Edit Instructions",
+                label="Edit Instructions",
                 placeholder="e.g. Replace the background with a beach sunset...",
                 lines=2,
                 max_lines=4
             )
 
+            # Add preset dropdown
             with gr.Row():
+                preset_dropdown = gr.Dropdown(
+                    choices=get_preset_choices(),
+                    value=None,
+                    label="Preset Batch Generation",
+                    interactive=True
+                )
                 rewrite_toggle = gr.Checkbox(
-                    label="Enable Prompt Enhancement",
+                    label="Enable Prompt Enhancement",
                     value=True,
                     interactive=True
                 )
                 run_button = gr.Button(
-                    "Generate Edits",
-                    variant="primary",
+                    "Generate Edits",
+                    variant="primary",
                     min_width=120
                 )
 
             with gr.Accordion("Advanced Parameters", open=False):
                 with gr.Row():
                     seed = gr.Slider(
-                        label="Seed",
-                        minimum=0,
-                        maximum=MAX_SEED,
-                        step=1,
+                        label="Seed",
+                        minimum=0,
+                        maximum=MAX_SEED,
+                        step=1,
                         value=42
                     )
                     randomize_seed = gr.Checkbox(
-                        label="Random Seed",
+                        label="Random Seed",
                         value=True
                     )
                 with gr.Row():
                     true_guidance_scale = gr.Slider(
-                        label="Guidance Scale",
-                        minimum=1.0,
-                        maximum=10.0,
-                        step=0.1,
+                        label="Guidance Scale",
+                        minimum=1.0,
+                        maximum=10.0,
+                        step=0.1,
                         value=4.0
                     )
                     num_inference_steps = gr.Slider(
-                        label="Inference Steps",
-                        minimum=4,
-                        maximum=16,
-                        step=1,
+                        label="Inference Steps",
+                        minimum=4,
+                        maximum=16,
+                        step=1,
                         value=8
                     )
                     num_images_per_prompt = gr.Slider(
-                        label="Output Count",
-                        minimum=1,
-                        maximum=4,
-                        step=1,
+                        label="Output Count (Manual)",
+                        minimum=1,
+                        maximum=4,
+                        step=1,
                         value=2
                     )
-
+
         # Output Column
         with gr.Column(scale=2):
             result = gr.Gallery(
@@ -512,18 +497,6 @@ with gr.Blocks(title="Qwen Image Edit - Fast Lightning Mode w/ Batch") as demo:
                 "Prompt details will appear after generation</div>"
             )
 
-    # # Examples
-    # gr.Examples(
-    #     examples=[
-    #         "Change the background scene to a rooftop bar at night",
-    #         "Transform to pixel art style with 8-bit graphics",
-    #         "Replace all text with 'Qwen AI' in futuristic font"
-    #     ],
-    #     inputs=[prompt],
-    #     label="Sample Instructions",
-    #     cache_examples=True
-    # )
-
     # Set up processing
     inputs = [
         input_image,
@@ -533,9 +506,9 @@ with gr.Blocks(title="Qwen Image Edit - Fast Lightning Mode w/ Batch") as demo:
         true_guidance_scale,
         num_inference_steps,
         rewrite_toggle,
-        num_images_per_prompt
+        num_images_per_prompt,
+        preset_dropdown  # Add preset dropdown to inputs
     ]
-
     outputs = [result, seed, prompt_info]
 
     run_button.click(
@@ -543,11 +516,11 @@ with gr.Blocks(title="Qwen Image Edit - Fast Lightning Mode w/ Batch") as demo:
         inputs=inputs,
         outputs=outputs
     )
-
     prompt.submit(
         fn=infer,
         inputs=inputs,
         outputs=outputs
     )
 
+
 demo.queue(max_size=5).launch()
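
A final wiring note: Gradio binds the `inputs` list to `infer`'s parameters positionally, so `preset_dropdown` (last in the list) lands on the new `preset_type=None` parameter, while `gr.Progress` is injected by Gradio rather than listed. The dropdown keeps `value=None` until the user picks a preset, and the `if preset_type and preset_type in PRESETS` guard treats both `None` and unknown names as "no preset". The intended mapping, with the two seed entries assumed from the surrounding code (the hunks elide them):

```python
# Positional mapping of the inputs list onto infer()'s signature.
component_order = [
    "input_image",            # -> image
    "prompt",                 # -> prompt
    "seed",                   # -> seed            (assumed; not visible in the hunks)
    "randomize_seed",         # -> randomize_seed  (assumed; not visible in the hunks)
    "true_guidance_scale",    # -> true_guidance_scale
    "num_inference_steps",    # -> num_inference_steps
    "rewrite_toggle",         # -> rewrite_prompt
    "num_images_per_prompt",  # -> num_images_per_prompt
    "preset_dropdown",        # -> preset_type (new in this PR)
]
```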