Added vae_slicing option
Browse files- pipeline.py +15 -0
pipeline.py
CHANGED
@@ -91,6 +91,21 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
|
|
91 |
# set slice_size = `None` to disable `attention slicing`
|
92 |
self.enable_attention_slicing(None)
|
93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
@torch.no_grad()
|
95 |
def __call__(
|
96 |
self,
|
|
|
91 |
# set slice_size = `None` to disable `attention slicing`
|
92 |
self.enable_attention_slicing(None)
|
93 |
|
94 |
+
def enable_vae_slicing(self):
|
95 |
+
r"""
|
96 |
+
Enable sliced VAE decoding.
|
97 |
+
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
|
98 |
+
steps. This is useful to save some memory and allow larger batch sizes.
|
99 |
+
"""
|
100 |
+
self.vae.enable_slicing()
|
101 |
+
|
102 |
+
def disable_vae_slicing(self):
|
103 |
+
r"""
|
104 |
+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
|
105 |
+
computing decoding in one step.
|
106 |
+
"""
|
107 |
+
self.vae.disable_slicing()
|
108 |
+
|
109 |
@torch.no_grad()
|
110 |
def __call__(
|
111 |
self,
|