girishwangikar commited on
Commit
c8e4067
·
verified ·
1 Parent(s): e98e685

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -59
app.py CHANGED
@@ -22,14 +22,12 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
22
  pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
23
 
24
  MAX_SEED = np.iinfo(np.int32).max
25
-
26
- MAX_IMAGE_SIZE = 1024
27
 
28
  # Few-shot examples
29
  few_shot_examples = [
30
  ("Create a birthday card for friend", "A vibrant birthday card with a colorful confetti background, featuring a large, playful 'Happy Birthday!' in the center. The card has a fun, festive theme with balloons, streamers, and a cupcake with a single lit candle. The message inside reads, 'Wishing you a day full of laughter and joy!'"),
31
  ("An educational infographic showing the stages of the water cycle with bright, engaging visuals.", "An educational infographic illustrating the water cycle. The diagram shows labeled stages including evaporation, condensation, precipitation, and collection, with arrows guiding the flow. The colors are bright and engaging, with clouds, raindrops, and a sun. The design is simple and clear, suitable for a classroom setting."),
32
-
33
  ]
34
 
35
  def generate_detailed_prompt(user_input):
@@ -38,19 +36,18 @@ def generate_detailed_prompt(user_input):
38
  Given a simple description, create an elaborate and detailed prompt that can be used to generate high-quality images.
39
  Your response should be concise and no longer than 3 sentences.
40
  Use the following examples as a guide for the level of detail and creativity expected:
41
-
42
  """ + "\n\n".join([f"Input: {input}\nOutput: {output}" for input, output in few_shot_examples]))
43
 
44
  human_message = HumanMessage(content=f"Generate a detailed image prompt based on this input, using no more than 3 sentences: {user_input}")
45
 
46
  response = llm([system_message, human_message])
47
  return response.content
48
-
49
  @spaces.GPU()
50
- def generate_image(prompt, seed=0, randomize_seed=False, width=1024, height=1024, num_inference_steps=4):
51
- if randomize_seed:
52
- seed = random.randint(0, MAX_SEED)
53
- generator = torch.Generator(device=device).manual_seed(seed if seed is not None else 0)
54
  image = pipe(
55
  prompt=prompt,
56
  width=width,
@@ -59,14 +56,10 @@ def generate_image(prompt, seed=0, randomize_seed=False, width=1024, height=1024
59
  generator=generator,
60
  guidance_scale=0.0
61
  ).images[0]
62
-
63
- return image, seed
64
 
 
65
 
66
  # Gradio UI setup
67
- import gradio as gr
68
-
69
- # Gradio UI
70
  css = """
71
  #col-container {
72
  margin: 0 auto;
@@ -87,13 +80,9 @@ css = """
87
  #result {
88
  margin-bottom: 20px;
89
  }
90
- #advanced-settings {
91
- margin-bottom: 20px;
92
- }
93
  """
94
 
95
  with gr.Blocks(css=css, theme='gradio/soft') as demo:
96
-
97
  with gr.Column(elem_id="col-container"):
98
  gr.Markdown("""
99
  # AI-Enhanced Image Generation
@@ -109,61 +98,23 @@ with gr.Blocks(css=css, theme='gradio/soft') as demo:
109
  container=False,
110
  elem_id="prompt"
111
  )
112
-
113
  run_button = gr.Button("Generate Image", scale=0)
114
 
115
  result = gr.Image(label="Result", show_label=False, elem_id="result")
116
 
117
- with gr.Accordion("Advanced Settings", open=False, elem_id="advanced-settings"):
118
- seed = gr.Slider(
119
- label="Seed",
120
- minimum=0,
121
- maximum=MAX_SEED,
122
- step=1,
123
- value=0,
124
- )
125
-
126
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
127
-
128
- with gr.Row():
129
- width = gr.Slider(
130
- label="Width",
131
- minimum=256,
132
- maximum=MAX_IMAGE_SIZE,
133
- step=32,
134
- value=1024,
135
- )
136
-
137
- height = gr.Slider(
138
- label="Height",
139
- minimum=256,
140
- maximum=MAX_IMAGE_SIZE,
141
- step=32,
142
- value=1024,
143
- )
144
-
145
- with gr.Row():
146
- num_inference_steps = gr.Slider(
147
- label="Number of inference steps",
148
- minimum=1,
149
- maximum=50,
150
- step=1,
151
- value=4,
152
- )
153
-
154
  gr.Examples(
155
  examples=[example[0] for example in few_shot_examples],
156
  inputs=[prompt],
157
  outputs=[result],
158
  fn=generate_image,
159
- cache_examples=False # Disable caching to avoid async issues
160
  )
161
 
162
  gr.on(
163
  triggers=[run_button.click, prompt.submit],
164
  fn=generate_image,
165
- inputs=[prompt, seed, randomize_seed, width, height, num_inference_steps],
166
- outputs=[result, seed]
167
  )
168
 
169
  demo.launch(share=True)
 
22
  pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
23
 
24
  MAX_SEED = np.iinfo(np.int32).max
25
+ MAX_IMAGE_SIZE = 512
 
26
 
27
  # Few-shot examples
28
  few_shot_examples = [
29
  ("Create a birthday card for friend", "A vibrant birthday card with a colorful confetti background, featuring a large, playful 'Happy Birthday!' in the center. The card has a fun, festive theme with balloons, streamers, and a cupcake with a single lit candle. The message inside reads, 'Wishing you a day full of laughter and joy!'"),
30
  ("An educational infographic showing the stages of the water cycle with bright, engaging visuals.", "An educational infographic illustrating the water cycle. The diagram shows labeled stages including evaporation, condensation, precipitation, and collection, with arrows guiding the flow. The colors are bright and engaging, with clouds, raindrops, and a sun. The design is simple and clear, suitable for a classroom setting."),
 
31
  ]
32
 
33
  def generate_detailed_prompt(user_input):
 
36
  Given a simple description, create an elaborate and detailed prompt that can be used to generate high-quality images.
37
  Your response should be concise and no longer than 3 sentences.
38
  Use the following examples as a guide for the level of detail and creativity expected:
39
+
40
  """ + "\n\n".join([f"Input: {input}\nOutput: {output}" for input, output in few_shot_examples]))
41
 
42
  human_message = HumanMessage(content=f"Generate a detailed image prompt based on this input, using no more than 3 sentences: {user_input}")
43
 
44
  response = llm([system_message, human_message])
45
  return response.content
46
+
47
  @spaces.GPU()
48
+ def generate_image(prompt, width=512, height=512, num_inference_steps=4):
49
+ seed = random.randint(0, MAX_SEED)
50
+ generator = torch.Generator(device=device).manual_seed(seed)
 
51
  image = pipe(
52
  prompt=prompt,
53
  width=width,
 
56
  generator=generator,
57
  guidance_scale=0.0
58
  ).images[0]
 
 
59
 
60
+ return image
61
 
62
  # Gradio UI setup
 
 
 
63
  css = """
64
  #col-container {
65
  margin: 0 auto;
 
80
  #result {
81
  margin-bottom: 20px;
82
  }
 
 
 
83
  """
84
 
85
  with gr.Blocks(css=css, theme='gradio/soft') as demo:
 
86
  with gr.Column(elem_id="col-container"):
87
  gr.Markdown("""
88
  # AI-Enhanced Image Generation
 
98
  container=False,
99
  elem_id="prompt"
100
  )
 
101
  run_button = gr.Button("Generate Image", scale=0)
102
 
103
  result = gr.Image(label="Result", show_label=False, elem_id="result")
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  gr.Examples(
106
  examples=[example[0] for example in few_shot_examples],
107
  inputs=[prompt],
108
  outputs=[result],
109
  fn=generate_image,
110
+ cache_examples=False
111
  )
112
 
113
  gr.on(
114
  triggers=[run_button.click, prompt.submit],
115
  fn=generate_image,
116
+ inputs=[prompt],
117
+ outputs=[result]
118
  )
119
 
120
  demo.launch(share=True)