copy from diffusers
Browse files- latent_consistency_controlnet.py +20 -15
latent_consistency_controlnet.py
CHANGED
@@ -25,7 +25,6 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
|
25 |
|
26 |
from diffusers import (
|
27 |
AutoencoderKL,
|
28 |
-
AutoencoderTiny,
|
29 |
ConfigMixin,
|
30 |
DiffusionPipeline,
|
31 |
SchedulerMixin,
|
@@ -50,6 +49,17 @@ import PIL.Image
|
|
50 |
|
51 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
class LatentConsistencyModelPipeline_controlnet(DiffusionPipeline):
|
54 |
_optional_components = ["scheduler"]
|
55 |
|
@@ -276,22 +286,17 @@ class LatentConsistencyModelPipeline_controlnet(DiffusionPipeline):
|
|
276 |
)
|
277 |
|
278 |
elif isinstance(generator, list):
|
279 |
-
|
280 |
-
|
281 |
-
self.vae.encode(image[i : i + 1])
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
init_latents = [
|
286 |
-
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i])
|
287 |
-
for i in range(batch_size)
|
288 |
-
]
|
289 |
init_latents = torch.cat(init_latents, dim=0)
|
290 |
else:
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
init_latents = self.vae.encode(image).latent_dist.sample(generator)
|
295 |
|
296 |
init_latents = self.vae.config.scaling_factor * init_latents
|
297 |
|
|
|
25 |
|
26 |
from diffusers import (
|
27 |
AutoencoderKL,
|
|
|
28 |
ConfigMixin,
|
29 |
DiffusionPipeline,
|
30 |
SchedulerMixin,
|
|
|
49 |
|
50 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
51 |
|
52 |
+
|
53 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
54 |
+
def retrieve_latents(encoder_output, generator):
|
55 |
+
if hasattr(encoder_output, "latent_dist"):
|
56 |
+
return encoder_output.latent_dist.sample(generator)
|
57 |
+
elif hasattr(encoder_output, "latents"):
|
58 |
+
return encoder_output.latents
|
59 |
+
else:
|
60 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
61 |
+
|
62 |
+
|
63 |
class LatentConsistencyModelPipeline_controlnet(DiffusionPipeline):
|
64 |
_optional_components = ["scheduler"]
|
65 |
|
|
|
286 |
)
|
287 |
|
288 |
elif isinstance(generator, list):
|
289 |
+
init_latents = [
|
290 |
+
retrieve_latents(
|
291 |
+
self.vae.encode(image[i : i + 1]), generator=generator[i]
|
292 |
+
)
|
293 |
+
for i in range(batch_size)
|
294 |
+
]
|
|
|
|
|
|
|
|
|
295 |
init_latents = torch.cat(init_latents, dim=0)
|
296 |
else:
|
297 |
+
init_latents = retrieve_latents(
|
298 |
+
self.vae.encode(image), generator=generator
|
299 |
+
)
|
|
|
300 |
|
301 |
init_latents = self.vae.config.scaling_factor * init_latents
|
302 |
|