ash123 commited on
Commit
343459a
1 Parent(s): 8d5a5ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -5
app.py CHANGED
@@ -12,7 +12,7 @@ pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
12
  )
13
 
14
 
15
- def generate_image(prompt: str, inference_steps: int = 30, prng_seed: int = 0):
16
  rng = jax.random.PRNGKey(int(prng_seed))
17
  rng = jax.random.split(rng, jax.device_count())
18
  p_params = replicate(pipeline_params)
@@ -20,6 +20,8 @@ def generate_image(prompt: str, inference_steps: int = 30, prng_seed: int = 0):
20
  num_samples = 1
21
  prompt_ids = pipeline.prepare_inputs([prompt] * num_samples)
22
  prompt_ids = shard(prompt_ids)
 
 
23
 
24
  images = pipeline(
25
  prompt_ids=prompt_ids,
@@ -28,6 +30,8 @@ def generate_image(prompt: str, inference_steps: int = 30, prng_seed: int = 0):
28
  height=128,
29
  width=128,
30
  num_inference_steps=int(inference_steps),
 
 
31
  jit=True,
32
  ).images
33
 
@@ -267,12 +271,15 @@ with block:
267
  minimum=1, maximum=100, default=25, step=1, label="Inference Steps"
268
  )
269
  seed_input = gr.inputs.Number(default=0, label="Seed")
 
 
 
270
 
271
- ex = gr.Examples(examples=[["A watercolor painting of a bird", 25, 0],["A watercolor painting of an otter",25,0],["Marvel MCU deadpool, red mask, red shirt, red gloves, black shoulders, black elbow pads, black legs, gold buckle, black belt, black mask, white eyes, black boots, fuji low light color 35mm film, downtown Osaka alley at night out of focus in background, neon lights",25,0]], fn=generate_image, inputs=[prompt_input, negative, inf_steps_input,seed_input ],outputs=[gallery, community_icon, loading_icon, share_button], cache_examples=True)
272
  ex.dataset.headers = [""]
273
- negative.submit(generate_image, inputs=[prompt_input, negative, inf_steps_input,seed_input], outputs=[gallery], postprocess=False)
274
- prompt_input.submit(generate_image, inputs=[prompt_input, negative, inf_steps_input,seed_input], outputs=[gallery], postprocess=False)
275
- btn.click(generate_image, inputs=[prompt_input, negative, inf_steps_input,seed_input], outputs=[gallery], postprocess=False)
276
 
277
  #advanced_button.click(
278
  # None,
 
12
  )
13
 
14
 
15
+ def generate_image(prompt: str,negative_prompt:str , inference_steps: int = 25, prng_seed: int = 0, guidance_scale: float = 9):
16
  rng = jax.random.PRNGKey(int(prng_seed))
17
  rng = jax.random.split(rng, jax.device_count())
18
  p_params = replicate(pipeline_params)
 
20
  num_samples = 1
21
  prompt_ids = pipeline.prepare_inputs([prompt] * num_samples)
22
  prompt_ids = shard(prompt_ids)
23
+ neg_prompt_ids = pipeline.prepare_inputs([negative_prompt] * num_samples)
24
+ neg_prompt_ids = shard(neg_prompt_ids)
25
 
26
  images = pipeline(
27
  prompt_ids=prompt_ids,
 
30
  height=128,
31
  width=128,
32
  num_inference_steps=int(inference_steps),
33
+ neg_prompt_ids=neg_prompt_ids,
34
+ guidance_scale =float(guidance_scale),
35
  jit=True,
36
  ).images
37
 
 
271
  minimum=1, maximum=100, default=25, step=1, label="Inference Steps"
272
  )
273
  seed_input = gr.inputs.Number(default=0, label="Seed")
274
+ guidance_scale = gr.Slider(
275
+ label="Guidance Scale", minimum=0, maximum=50, value=9, step=0.1
276
+ )
277
 
278
+ ex = gr.Examples(examples=[["A watercolor painting of a bird","mountain", 25, 0,9],["A watercolor painting of an otter","mountain",25,0,9],["Marvel MCU deadpool, red mask, red shirt, red gloves, black shoulders, black elbow pads, black legs, gold buckle, black belt, black mask, white eyes, black boots, fuji low light color 35mm film, downtown Osaka alley at night out of focus in background, neon lights","mountain",25,0,10]], fn=generate_image, inputs=[prompt_input, negative, inf_steps_input,seed_input,guidance_scale],outputs=[gallery, community_icon, loading_icon, share_button], cache_examples=True)
279
  ex.dataset.headers = [""]
280
+ negative.submit(generate_image, inputs=[prompt_input, negative, inf_steps_input,seed_input,guidance_scale], outputs=[gallery], postprocess=False)
281
+ prompt_input.submit(generate_image, inputs=[prompt_input, negative, inf_steps_input,seed_input,guidance_scale], outputs=[gallery], postprocess=False)
282
+ btn.click(generate_image, inputs=[prompt_input, negative, inf_steps_input,seed_input,guidance_scale], outputs=[gallery], postprocess=False)
283
 
284
  #advanced_button.click(
285
  # None,