kadirnar commited on
Commit
f2ab5fc
Β·
verified Β·
1 Parent(s): 57e5f35

Update stable_cascade.py

Browse files
Files changed (1) hide show
  1. stable_cascade.py +4 -1
stable_cascade.py CHANGED
@@ -5,7 +5,10 @@ import gradio as gr
5
 
6
  # Initialize the prior and decoder pipelines
7
  prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16).to("cuda")
 
 
8
  decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=torch.float16).to("cuda")
 
9
 
10
  def generate_images(
11
  prompt="a photo of a girl",
@@ -53,7 +56,7 @@ def generate_images(
53
  num_inference_steps=decoder_inference_steps
54
  ).images
55
 
56
- return decoder_output[0]
57
 
58
 
59
  def web_demo():
 
5
 
6
  # Initialize the prior and decoder pipelines
7
  prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16).to("cuda")
8
+ prior.enable_xformers_memory_efficient_attention()
9
+
10
  decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=torch.float16).to("cuda")
11
+ decoder.enable_xformers_memory_efficient_attention()
12
 
13
  def generate_images(
14
  prompt="a photo of a girl",
 
56
  num_inference_steps=decoder_inference_steps
57
  ).images
58
 
59
+ return decoder_output
60
 
61
 
62
  def web_demo():