MohamedRashad commited on
Commit
836dd96
·
1 Parent(s): c0ec201

Add prompt enhancement functionality and integrate Gradio client in app.py; update requirements.txt

Browse files
Files changed (2) hide show
  1. app.py +68 -5
  2. requirements.txt +3 -1
app.py CHANGED
@@ -13,6 +13,7 @@ import re
13
  import random
14
  from pathlib import Path
15
  from typing import List
 
16
 
17
  import cv2
18
  import numpy as np
@@ -29,8 +30,10 @@ import spaces
29
  from models.infinity import Infinity
30
  from models.basic import *
31
  from utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates
 
32
 
33
  torch._dynamo.config.cache_size_limit = 64
 
34
 
35
  # Define a function to download weights if not present
36
  def download_infinity_weights(weights_path):
@@ -357,6 +360,60 @@ def load_transformer(vae, args):
357
  )
358
  return infinity
359
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
360
  # Set up paths
361
  weights_path = Path(__file__).parent / 'weights'
362
  weights_path.mkdir(exist_ok=True)
@@ -380,7 +437,6 @@ args = argparse.Namespace(
380
  rope2d_normalized_by_hw=2,
381
  use_scale_schedule_embedding=0,
382
  sampling_per_bits=1,
383
- text_encoder_ckpt=str(weights_path / 'flan-t5-xl'),
384
  text_channels=2048,
385
  apply_spatial_patchify=0,
386
  h_div_w_template=1.000,
@@ -400,7 +456,7 @@ infinity = load_transformer(vae, args)
400
 
401
  # Define the image generation function
402
  @spaces.GPU
403
- def generate_image(prompt, cfg, tau, h_div_w, seed, enable_positive_prompt):
404
  try:
405
  args.prompt = prompt
406
  args.cfg = cfg
@@ -454,8 +510,8 @@ with gr.Blocks() as demo:
454
  # Prompt Settings
455
  gr.Markdown("### Prompt Settings")
456
  prompt = gr.Textbox(label="Prompt", value="alien spaceship enterprise", placeholder="Enter your prompt here...")
457
- enable_positive_prompt = gr.Checkbox(label="Enable Positive Prompt", value=False, info="Enhance prompts with positive attributes for faces.")
458
-
459
  # Image Settings
460
  gr.Markdown("### Image Settings")
461
  with gr.Row():
@@ -477,10 +533,17 @@ with gr.Blocks() as demo:
477
  # Error Handling
478
  error_message = gr.Textbox(label="Error Message", visible=False)
479
 
 
 
 
 
 
 
 
480
  # Link the generate button to the image generation function
481
  generate_button.click(
482
  generate_image,
483
- inputs=[prompt, cfg, tau, h_div_w, seed, enable_positive_prompt],
484
  outputs=output_image
485
  )
486
 
 
13
  import random
14
  from pathlib import Path
15
  from typing import List
16
+ import json
17
 
18
  import cv2
19
  import numpy as np
 
30
  from models.infinity import Infinity
31
  from models.basic import *
32
  from utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates
33
+ from gradio_client import Client
34
 
35
  torch._dynamo.config.cache_size_limit = 64
36
+ client = Client("Qwen/Qwen2.5-72B-Instruct")
37
 
38
  # Define a function to download weights if not present
39
  def download_infinity_weights(weights_path):
 
360
  )
361
  return infinity
362
 
363
+ def enhance_prompt(prompt):
364
+ SYSTEM = """You are part of a team of bots that creates images. You work with an assistant bot that will draw anything you say.
365
+
366
+ When given a user prompt, your role is to transform it into a creative, detailed, and vivid image description. Additionally, you will assign a configuration value (`cfg`) based on the type of image.
367
+
368
+ ### Guidelines for Generating the Output:
369
+
370
+ 1. **Output Format:**
371
+ Your response must be in the following dictionary format:
372
+ ```json
373
+ {
374
+ "prompt": "<enhanced image description>",
375
+ "cfg": <cfg value>
376
+ }
377
+ ```
378
+
379
+ 2. **Enhancing the "prompt" field:**
380
+ - Use your creativity to transform short or vague prompts into highly detailed, descriptive, and imaginative image descriptions.
381
+ - Preserve the original intent and meaning of the user’s input.
382
+ - Focus on vivid imagery, sensory details, and emotional resonance in your descriptions.
383
+ - For particularly long user prompts (over 50 words), output them directly without refinement.
384
+ - Image descriptions must remain between 8-512 words. Any excess text will be ignored.
385
+ - If the user's request involves rendering specific text in the image, enclose that text in single quotation marks and prefix it with "the text".
386
+
387
+ 3. **Determining the "cfg" field:**
388
+ - If the image to be generated is likely to feature a clear face, set `"cfg": 1`.
389
+ - If the image does not prominently feature a face, set `"cfg": 3`.
390
+
391
+ 4. **Examples of Enhanced Prompts:**
392
+ - **User prompt:** "a tree"
393
+ **Enhanced prompt:** "A photo of a majestic oak tree stands proudly in the middle of a sunlit meadow, its branches stretching out like welcoming arms. The leaves shimmer in shades of vibrant green, casting dappled shadows on the soft grass below."
394
+ **Cfg:** `3`
395
+
396
+ - **User prompt:** "a cat by the window"
397
+ **Enhanced prompt:** "A serene scene of a fluffy tabby cat perched on the windowsill, gazing out at the golden hues of a sunset. The soft light filters through lace curtains, highlighting the cat’s delicate whiskers and its relaxed posture."
398
+ **Cfg:** `3`
399
+
400
+ 5. **Your Output:**
401
+ Always return a single dictionary containing both `"prompt"` and `"cfg"` fields. Avoid any additional commentary or explanations.
402
+
403
+ Don't write anything except the dictionary in the output. (Don't start with ```)
404
+ """
405
+ result = client.predict(
406
+ query=prompt,
407
+ history=[],
408
+ system=SYSTEM,
409
+ api_name="/model_chat"
410
+ )
411
+
412
+ dict_of_inputs = json.loads(result[1][-1][-1])
413
+ print(dict_of_inputs)
414
+
415
+ return gr.update(value=dict_of_inputs["prompt"]), gr.update(value=float(dict_of_inputs['cfg']))
416
+
417
  # Set up paths
418
  weights_path = Path(__file__).parent / 'weights'
419
  weights_path.mkdir(exist_ok=True)
 
437
  rope2d_normalized_by_hw=2,
438
  use_scale_schedule_embedding=0,
439
  sampling_per_bits=1,
 
440
  text_channels=2048,
441
  apply_spatial_patchify=0,
442
  h_div_w_template=1.000,
 
456
 
457
  # Define the image generation function
458
  @spaces.GPU
459
+ def generate_image(prompt, cfg, tau, h_div_w, seed, enable_positive_prompt=False):
460
  try:
461
  args.prompt = prompt
462
  args.cfg = cfg
 
510
  # Prompt Settings
511
  gr.Markdown("### Prompt Settings")
512
  prompt = gr.Textbox(label="Prompt", value="alien spaceship enterprise", placeholder="Enter your prompt here...")
513
+ enhance_prompt_button = gr.Button("Enhance Prompt", variant="secondary")
514
+
515
  # Image Settings
516
  gr.Markdown("### Image Settings")
517
  with gr.Row():
 
533
  # Error Handling
534
  error_message = gr.Textbox(label="Error Message", visible=False)
535
 
536
+ # Link the enhance prompt button to the prompt enhancement function
537
+ enhance_prompt_button.click(
538
+ enhance_prompt,
539
+ inputs=prompt,
540
+ outputs=[prompt, cfg],
541
+ )
542
+
543
  # Link the generate button to the image generation function
544
  generate_button.click(
545
  generate_image,
546
+ inputs=[prompt, cfg, tau, h_div_w, seed],
547
  outputs=output_image
548
  )
549
 
requirements.txt CHANGED
@@ -6,4 +6,6 @@ transformers
6
  argparse
7
  spaces
8
  torchvision
9
- timm
 
 
 
6
  argparse
7
  spaces
8
  torchvision
9
+ timm
10
+ gradio_client
11
+ imageio