linoyts HF Staff committed on
Commit 01b2d20 · verified · 1 Parent(s): 2afde1e

Update app.py

Files changed (1)
  1. app.py +147 -160
app.py CHANGED
@@ -18,8 +18,9 @@ from safetensors.torch import load_file
 
 import os
 import base64
-import BytesIO
+from io import BytesIO
 import json
+import time
 
 SYSTEM_PROMPT = '''
 # Edit Instruction Rewriter
@@ -187,208 +188,157 @@ def polish_prompt_hf(prompt, img_list):
         result = completion.choices[0].message.content
 
         # Try to extract JSON if present
-        if '{"Rewritten"' in result:
+        if '{"Rewritten"' in result or '"Rewritten"' in result:
             try:
-                # Clean up the response
-                result = result.replace('```json', '').replace('```', '')
-                result_json = json.loads(result)
-                polished_prompt = result_json.get('Rewritten', result)
-            except:
+                result = result.replace('```json', '').replace('```', '').strip()
+                if result.startswith('{') and result.endswith('}'):
+                    result_json = json.loads(result)
+                    polished_prompt = result_json.get('Rewritten', result)
+                else:
+                    polished_prompt = result
+            except Exception as e:
+                print(f"JSON parsing failed: {e}")
                 polished_prompt = result
         else:
             polished_prompt = result
 
         polished_prompt = polished_prompt.strip().replace("\n", " ")
+        print(f"Polished prompt from HF: {polished_prompt}")
         return polished_prompt
 
     except Exception as e:
         print(f"Error during API call to Hugging Face: {e}")
-        # Fallback to original prompt if enhancement fails
         return prompt
-
-def next_scene_prompt(original_prompt, img_list):
+
+def encode_image(img):
+    """Encode PIL Image to base64 string."""
+    buffer = BytesIO()
+    img.save(buffer, format="PNG")
+    return base64.b64encode(buffer.getvalue()).decode()
+
+def suggest_next_scene_prompt_hf(img_list):
     """
-    Rewrites the prompt using a Hugging Face InferenceClient.
-    Supports multiple images via img_list.
+    Generate a cinematic "Next Scene" prompt using Hugging Face InferenceClient.
     """
-    # Ensure HF_TOKEN is set
     api_key = os.environ.get("HF_TOKEN")
-    if not api_key:
-        print("Warning: HF_TOKEN not set. Falling back to original prompt.")
-        return original_prompt
-    prompt = f"{NEXT_SCENE_SYSTEM_PROMPT}"
-    system_prompt = "you are a helpful assistant, you should provide useful answers to users."
+    if not api_key or not img_list:
+        return ""
+
     try:
-        # Initialize the client
         client = InferenceClient(
-            provider="nebius",
+            provider="cerebras",
             api_key=api_key,
         )
-
-        # Convert list of images to base64 data URLs
-        image_urls = []
-        if img_list is not None:
-            # Ensure img_list is actually a list
-            if not isinstance(img_list, list):
-                img_list = [img_list]
-
-            for img in img_list:
-                image_url = None
-                # If img is a PIL Image
-                if hasattr(img, 'save'): # Check if it's a PIL Image
-                    buffered = BytesIO()
-                    img.save(buffered, format="PNG")
-                    img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
-                    image_url = f"data:image/png;base64,{img_base64}"
-                # If img is already a file path (string)
-                elif isinstance(img, str):
-                    with open(img, "rb") as image_file:
-                        img_base64 = base64.b64encode(image_file.read()).decode('utf-8')
-                    image_url = f"data:image/png;base64,{img_base64}"
-                else:
-                    print(f"Warning: Unexpected image type: {type(img)}, skipping...")
-                    continue
-
-                if image_url:
-                    image_urls.append(image_url)
-
-        # Build the content array with text first, then all images
-        content = [
-            {
-                "type": "text",
-                "text": prompt
-            }
-        ]
 
-        # Add all images to the content
-        for image_url in image_urls:
-            content.append({
-                "type": "image_url",
-                "image_url": {
-                    "url": image_url
-                }
-            })
-
-        # Format the messages for the chat completions API
         messages = [
-            {"role": "system", "content": system_prompt},
-            {
-                "role": "user",
-                "content": content
-            }
+            {"role": "system", "content": NEXT_SCENE_SYSTEM_PROMPT},
+            {"role": "user", "content": []}
         ]
-
-        # Call the API
+
+        for img in img_list:
+            messages[1]["content"].append(
+                {"image": f"data:image/png;base64,{encode_image(img)}"})
+        messages[1]["content"].append({"text": "Generate a natural next scene prompt for this image."})
+
         completion = client.chat.completions.create(
-            model="Qwen/Qwen2.5-VL-72B-Instruct",
+            model="Qwen/Qwen3-235B-A22B-Instruct-2507",
             messages=messages,
         )
 
-        # Parse the response
-        result = completion.choices[0].message.content
-
-        # Try to extract JSON if present
-        if '"Rewritten"' in result:
-            try:
-                # Clean up the response
-                result = result.replace('```json', '').replace('```', '')
-                result_json = json.loads(result)
-                polished_prompt = result_json.get('Rewritten', result)
-            except:
-                polished_prompt = result
-        else:
-            polished_prompt = result
-
-        polished_prompt = polished_prompt.strip().replace("\n", " ")
-        return polished_prompt
+        result = completion.choices[0].message.content.strip()
+        print(f"Generated Next Scene prompt: {result}")
+        return result
 
     except Exception as e:
-        print(f"Error during API call to Hugging Face: {e}")
-        # Fallback to original prompt if enhancement fails
-        return original_prompt
-
-
+        print(f"Error generating next scene prompt: {e}")
+        return ""
 
-def encode_image(pil_image):
-    import io
-    buffered = io.BytesIO()
-    pil_image.save(buffered, format="PNG")
-    return base64.b64encode(buffered.getvalue()).decode("utf-8")
+def suggest_next_scene_prompt(images):
+    """
+    Wrapper function to generate next scene prompt from image gallery.
+    """
+    if not images:
+        return ""
+
+    pil_images = []
+    for item in images:
+        try:
+            if isinstance(item[0], Image.Image):
+                pil_images.append(item[0].convert("RGB"))
+            elif isinstance(item[0], str):
+                pil_images.append(Image.open(item[0]).convert("RGB"))
+            elif hasattr(item, "name"):
+                pil_images.append(Image.open(item.name).convert("RGB"))
+        except Exception as e:
+            print(f"Error processing image: {e}")
+            continue
+
+    if not pil_images:
+        return ""
+
+    return suggest_next_scene_prompt_hf(pil_images)
 
 # --- Model Loading ---
 dtype = torch.bfloat16
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-pipe = QwenImageEditPlusPipeline.from_pretrained("Qwen/Qwen-Image-Edit-2509",
-    transformer= QwenImageTransformer2DModel.from_pretrained("linoyts/Qwen-Image-Edit-Rapid-AIO",
-        subfolder='transformer',
-        torch_dtype=dtype,
-        device_map='cuda'),torch_dtype=dtype).to(device)
-
-pipe.load_lora_weights(
-    "lovis93/next-scene-qwen-image-lora-2509",
-    weight_name="next-scene_lora-v2-3000.safetensors", adapter_name="next-scene"
-)
-pipe.set_adapters(["next-scene"], adapter_weights=[1.])
-pipe.fuse_lora(adapter_names=["next-scene"], lora_scale=1.)
-pipe.unload_lora_weights()
-
-
-# Apply the same optimizations from the first version
+pipe = QwenImageEditPlusPipeline.from_pretrained("Phr00t/Qwen-Image-Edit-Rapid-AIO", torch_dtype=dtype).to(device)
 pipe.transformer.__class__ = QwenImageTransformer2DModel
 pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
 
 # --- Ahead-of-time compilation ---
-optimize_pipeline_(pipe, image=[Image.new("RGB", (1024, 1024)), Image.new("RGB", (1024, 1024))], prompt="prompt")
+optimize_pipeline_(pipe, image=Image.new("RGB", (1024, 1024)), prompt="prompt")
 
-# --- UI Constants and Helpers ---
+# --- Constants ---
 MAX_SEED = np.iinfo(np.int32).max
 
+# --- Helper Functions ---
 def use_output_as_input(output_images):
-    """Convert output images to input format for the gallery"""
-    if output_images is None or len(output_images) == 0:
-        return []
-    return output_images
-
-def suggest_next_scene_prompt(images):
-    pil_images = []
-    if images is not None:
-        for item in images:
-            try:
-                if isinstance(item[0], Image.Image):
-                    pil_images.append(item[0].convert("RGB"))
-                elif isinstance(item[0], str):
-                    pil_images.append(Image.open(item[0]).convert("RGB"))
-                elif hasattr(item, "name"):
-                    pil_images.append(Image.open(item.name).convert("RGB"))
-            except Exception:
-                continue
-    if len(pil_images) > 0:
-        prompt = next_scene_prompt("", pil_images)
-    else:
-        prompt = ""
-    print("next scene prompt: ", prompt)
-    return prompt
-
-# --- Main Inference Function (with hardcoded negative prompt) ---
-@spaces.GPU(duration=300)
+    """Convert the first image from the result gallery to input format."""
+    if output_images and len(output_images) > 0:
+        # output_images is a list of images
+        first_image = output_images[0]
+        # Return in the format expected by the Gallery: list of tuples
+        return [first_image]
+    return None
+
+def update_history(new_images, history):
+    """Updates the history gallery with new images."""
+    time.sleep(0.5)  # Small delay to ensure images are ready
+    if history is None:
+        history = []
+    if new_images is not None and len(new_images) > 0:
+        # Convert to list if needed
+        if not isinstance(history, list):
+            history = list(history) if history else []
+        # Add all new images to the beginning of history
+        for img in new_images:
+            history.insert(0, img)
+        # Keep only the last 20 images in history
+        history = history[:20]
+    return history
+
+def use_history_as_input(evt: gr.SelectData):
+    """Sets the selected history image as the new input image."""
+    # evt.value contains the selected image
+    if evt.value is not None:
+        return [evt.value]
+    return None
+
+# --- Inference Function ---
+@spaces.GPU
 def infer(
-    images,
-    prompt,
-    seed=42,
-    randomize_seed=False,
-    true_guidance_scale=1.0,
-    num_inference_steps=4,
+    images,
+    prompt,
+    seed=42,
+    randomize_seed=False,
+    true_guidance_scale=1.0,
+    num_inference_steps=8,
     height=None,
     width=None,
-    rewrite_prompt=True,
-    num_images_per_prompt=1,
-    progress=gr.Progress(track_tqdm=True),
+    rewrite_prompt=False,
+    num_images_per_prompt=1
 ):
-    """
-    Generates an image using the local Qwen-Image diffusers pipeline.
-    """
-    # Hardcode the negative prompt as requested
     negative_prompt = " "
 
     if randomize_seed:
@@ -478,6 +428,22 @@ with gr.Blocks(css=css) as demo:
                 result = gr.Gallery(label="Result", show_label=False, type="pil")
                 # Add this button right after the result gallery - initially hidden
                 use_output_btn = gr.Button("↗️ Use as input", variant="secondary", size="sm", visible=False)
+
+                # Add history section
+                gr.Markdown("---")
+                with gr.Row():
+                    gr.Markdown("### 📜 History")
+                    clear_history_button = gr.Button("🗑️ Clear History", size="sm", variant="stop")
+
+                history_gallery = gr.Gallery(
+                    label="Click any image to use as input",
+                    columns=4,
+                    rows=2,
+                    object_fit="contain",
+                    height="auto",
+                    interactive=False,
+                    show_label=True
+                )
 
         with gr.Row():
             prompt = gr.Text(
@@ -540,6 +506,7 @@ with gr.Blocks(css=css) as demo:
 
     # gr.Examples(examples=examples, inputs=[prompt], outputs=[result, seed], fn=infer, cache_examples=False)
 
+    # Main generation events
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
@@ -554,15 +521,35 @@ with gr.Blocks(css=css) as demo:
            width,
            rewrite_prompt,
        ],
-        outputs=[result, seed, use_output_btn], # Added use_output_btn to outputs
+        outputs=[result, seed, use_output_btn],
+    ).then(
+        fn=update_history,
+        inputs=[result, history_gallery],
+        outputs=history_gallery,
+        show_api=False
    )
 
-    # Add the new event handler for the "Use Output as Input" button
+    # Add the event handler for the "Use Output as Input" button
    use_output_btn.click(
        fn=use_output_as_input,
        inputs=[result],
        outputs=[input_images]
    )
+
+    # History gallery select handler
+    history_gallery.select(
+        fn=use_history_as_input,
+        outputs=[input_images],
+        show_api=False
+    )
+
+    # Clear history button
+    clear_history_button.click(
+        fn=lambda: [],
+        inputs=None,
+        outputs=history_gallery,
+        show_api=False
+    )
 
    input_images.change(fn=suggest_next_scene_prompt, inputs=[input_images], outputs=[prompt])
 
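
Note (not part of the commit): the new wiring chains the main generation event into a history gallery with .then() and routes gallery clicks back to the input via .select(). The following minimal Gradio sketch shows that pattern in isolation; the component names and the stubbed fake_generate function are illustrative assumptions, not code from app.py.

import gradio as gr
from PIL import Image

def fake_generate(prompt):
    # Stand-in for the Space's real infer() call.
    return [Image.new("RGB", (256, 256), "steelblue")]

def update_history(new_images, history):
    # Prepend the newest results and keep only the 20 most recent entries.
    history = list(history) if history else []
    for img in new_images or []:
        history.insert(0, img)
    return history[:20]

def use_history_as_input(evt: gr.SelectData):
    # evt.value is the gallery item the user clicked.
    return [evt.value] if evt.value is not None else None

with gr.Blocks() as demo:
    input_images = gr.Gallery(label="Input")
    prompt = gr.Text(label="Prompt")
    run_button = gr.Button("Generate")
    result = gr.Gallery(label="Result")
    history_gallery = gr.Gallery(label="History", interactive=False)

    # Generation event chained into the history gallery, as in the commit.
    run_button.click(fn=fake_generate, inputs=[prompt], outputs=[result]).then(
        fn=update_history, inputs=[result, history_gallery], outputs=history_gallery
    )
    # Clicking a history image copies it back to the input gallery.
    history_gallery.select(fn=use_history_as_input, outputs=[input_images])

demo.launch()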