AlanB committed on
Commit 1618ac2
1 Parent(s): 6093aac

Added vae_slicing option

Files changed (1)
  1. pipeline.py +15 -0
pipeline.py CHANGED
@@ -91,6 +91,21 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
 
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding.
+        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+        steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
     @torch.no_grad()
     def __call__(
         self,
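
For reference, a minimal usage sketch of the new option. This is not part of the commit; the model id and the `custom_pipeline` reference below are assumptions for illustration, and the prompt/batch size are arbitrary.

    import torch
    from diffusers import DiffusionPipeline

    # Load this file as a custom pipeline (model id and custom_pipeline
    # path are hypothetical placeholders, adjust to your setup).
    pipe = DiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        custom_pipeline="./",  # directory containing pipeline.py
        torch_dtype=torch.float16,
    ).to("cuda")

    # New in this commit: decode the VAE output in slices to lower peak
    # memory, which allows larger batch sizes.
    pipe.enable_vae_slicing()
    images = pipe(["a photo of an astronaut riding a horse"] * 4).images

    # Revert to decoding the whole batch in one step.
    pipe.disable_vae_slicing()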