ehristoforu commited on
Commit
5bdb9be
1 Parent(s): 99df419

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -5
app.py CHANGED
@@ -2,6 +2,7 @@ import torch
2
  import gradio as gr
3
  from PIL import Image
4
  import spaces
 
5
  from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
6
 
7
  device = "cuda"
@@ -40,9 +41,13 @@ footer {
40
  """
41
 
42
  @spaces.GPU
43
- def gen(prompt, negative, width, height):
 
 
 
 
44
  prior_output = prior(
45
- prompt=f"{prompt}, {prompt_add}",
46
  height=height,
47
  width=width,
48
  negative_prompt=negative,
@@ -52,7 +57,7 @@ def gen(prompt, negative, width, height):
52
  )
53
  decoder_output = decoder(
54
  image_embeddings=prior_output.image_embeddings.half(),
55
- prompt=f"{prompt}, {prompt_add}",
56
  negative_prompt=negative,
57
  guidance_scale=0.0,
58
  output_type="pil",
@@ -71,9 +76,11 @@ with gr.Blocks(css=css) as demo:
71
  with gr.Row():
72
  width = gr.Slider(label="Width", minimum=1024, maximum=2048, step=8, value=1024, interactive=True)
73
  height = gr.Slider(label="Height", minimum=1024, maximum=2048, step=8, value=1024, interactive=True)
 
 
74
  with gr.Row():
75
  gallery = gr.Gallery(show_label=False, rows=1, columns=1, allow_preview=True, preview=True)
76
 
77
- button.click(gen, inputs=[prompt, negative, width, height], outputs=gallery)
78
 
79
- demo.launch(show_api=False)
 
2
  import gradio as gr
3
  from PIL import Image
4
  import spaces
5
+ import tqdm
6
  from diffusers import StableCascadeDecoderPipeline, StableCascadePriorPipeline
7
 
8
  device = "cuda"
 
41
  """
42
 
43
  @spaces.GPU
44
+ def gen(prompt, negative, width, height, use_add, progress=gr.Progress()):
45
+ if use_add:
46
+ text = f"{prompt}, {prompt_add}"
47
+ else:
48
+ text = f"{prompt"
49
  prior_output = prior(
50
+ prompt=text,
51
  height=height,
52
  width=width,
53
  negative_prompt=negative,
 
57
  )
58
  decoder_output = decoder(
59
  image_embeddings=prior_output.image_embeddings.half(),
60
+ prompt=text,
61
  negative_prompt=negative,
62
  guidance_scale=0.0,
63
  output_type="pil",
 
76
  with gr.Row():
77
  width = gr.Slider(label="Width", minimum=1024, maximum=2048, step=8, value=1024, interactive=True)
78
  height = gr.Slider(label="Height", minimum=1024, maximum=2048, step=8, value=1024, interactive=True)
79
+ with gr.Row():
80
+ use_add = gr.Checkbox(label="Use prompt addition", value=True, interactive=True)
81
  with gr.Row():
82
  gallery = gr.Gallery(show_label=False, rows=1, columns=1, allow_preview=True, preview=True)
83
 
84
+ button.click(gen, inputs=[prompt, negative, width, height, use_add], outputs=gallery)
85
 
86
+ demo.queue().launch(show_api=False)