juancopi81 commited on
Commit
b5d5c28
β€’
1 Parent(s): 85b8db7

Add advanced option

Browse files
Files changed (1) hide show
  1. app.py +31 -14
app.py CHANGED
@@ -1,5 +1,7 @@
 
1
  import gradio as gr
2
  import torch
 
3
  from diffusers import DiffusionPipeline
4
  import streamlit as st
5
  from transformers import (
@@ -10,6 +12,8 @@ from transformers import (
10
 
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
  device_dict = {"cuda": 0, "cpu": -1}
 
 
13
 
14
  # Add language detection pipeline
15
  language_detection_model_ckpt = "papluca/xlm-roberta-base-language-detection"
@@ -30,19 +34,21 @@ pipe = DiffusionPipeline.from_pretrained(
30
  detection_pipeline=language_detection_pipeline,
31
  translation_model=trans_model,
32
  translation_tokenizer=trans_tokenizer,
33
- # revision="fp16",
34
- # torch_dtype=torch.float16,
35
  )
36
 
37
- pipe.enable_attention_slicing()
38
  pipe = pipe.to(device)
39
 
40
  #torch.backends.cudnn.benchmark = True
41
  num_samples = 2
42
 
43
- def infer(prompt):
44
- output = pipe([prompt] * num_samples)
45
- return output.images
 
 
 
46
 
47
  css = """
48
  .gradio-container {
@@ -100,7 +106,6 @@ css = """
100
  border-radius: 14px !important;
101
  }
102
  #advanced-options {
103
- display: none;
104
  margin-bottom: 20px;
105
  }
106
  .footer {
@@ -167,13 +172,19 @@ block = gr.Blocks(css=css)
167
 
168
  examples = [
169
  [
170
- 'Una casa en la playa en un atardecer lluvioso'
 
 
171
  ],
172
  [
173
- 'Ein Hund, der Orange isst'
 
 
174
  ],
175
  [
176
- "Photo d'un restaurant parisien"
 
 
177
  ],
178
  ]
179
 
@@ -216,14 +227,20 @@ with block as demo:
216
  )
217
 
218
  gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="generated_id").style(
219
- grid=[2], height="auto"
220
  )
221
 
222
- ex = gr.Examples(examples=examples, fn=infer, inputs=[text], outputs=gallery, cache_examples=False)
 
 
 
 
 
 
223
  ex.dataset.headers = [""]
224
 
225
- text.submit(infer, inputs=[text], outputs=gallery)
226
- btn.click(infer, inputs=[text], outputs=gallery)
227
 
228
  gr.HTML(
229
  """
 
1
+ from contextlib import nullcontext
2
  import gradio as gr
3
  import torch
4
+ from torch import autocast
5
  from diffusers import DiffusionPipeline
6
  import streamlit as st
7
  from transformers import (
 
12
 
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
  device_dict = {"cuda": 0, "cpu": -1}
15
+ context = autocast if device == "cuda" else nullcontext
16
+ dtype = torch.float16 if device == "cuda" else torch.float32
17
 
18
  # Add language detection pipeline
19
  language_detection_model_ckpt = "papluca/xlm-roberta-base-language-detection"
 
34
  detection_pipeline=language_detection_pipeline,
35
  translation_model=trans_model,
36
  translation_tokenizer=trans_tokenizer,
37
+ revision="fp16",
38
+ torch_dtype=dtype,
39
  )
40
 
 
41
  pipe = pipe.to(device)
42
 
43
  #torch.backends.cudnn.benchmark = True
44
  num_samples = 2
45
 
46
+ def infer(prompt, scale, steps):
47
+
48
+ with context("cuda"):
49
+ images = pipe(num_samples*[prompt], guidance_scale=scale, num_inference_steps=steps).images
50
+
51
+ return images
52
 
53
  css = """
54
  .gradio-container {
 
106
  border-radius: 14px !important;
107
  }
108
  #advanced-options {
 
109
  margin-bottom: 20px;
110
  }
111
  .footer {
 
172
 
173
  examples = [
174
  [
175
+ 'Una casa en la playa en un atardecer lluvioso',
176
+ 45,
177
+ 7.5,
178
  ],
179
  [
180
+ 'Ein Hund, der Orange isst',
181
+ 45,
182
+ 7.5,
183
  ],
184
  [
185
+ "Photo d'un restaurant parisien",
186
+ 45,
187
+ 7.5,
188
  ],
189
  ]
190
 
 
227
  )
228
 
229
  gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="generated_id").style(
230
+ grid=[1], height="auto"
231
  )
232
 
233
+ with gr.Row(elem_id="advanced-options"):
234
+ steps = gr.Slider(label="Steps", minimum=5, maximum=50, value=45, step=5)
235
+ scale = gr.Slider(
236
+ label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
237
+ )
238
+
239
+ ex = gr.Examples(examples=examples, fn=infer, inputs=[text, steps, scale], outputs=gallery, cache_examples=False)
240
  ex.dataset.headers = [""]
241
 
242
+ text.submit(infer, inputs=[text, steps, scale], outputs=gallery)
243
+ btn.click(infer, inputs=[text, steps, scale], outputs=gallery)
244
 
245
  gr.HTML(
246
  """