Update pipeline.py
pipeline.py (+60 -25)
```diff
@@ -1,18 +1,20 @@
 import inspect
 import re
-import PIL
+from typing import Callable, List, Optional, Union
+
 import numpy as np
 import torch
-from typing import Callable, List, Optional, Union
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
+import PIL
 from diffusers.configuration_utils import FrozenDict
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipeline_utils import DiffusionPipeline
-from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from diffusers.utils import deprecate, logging
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from diffusers.utils import deprecate, logging
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
```
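Import housekeeping: the block is regrouped isort-style (standard library and `typing` first, then third-party packages alphabetically). The deleted third line appears to be the old `PIL` import, re-added lower down, so this hunk looks like a pure reordering with no behavioral change.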
```diff
@@ -130,6 +132,7 @@ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_len
     """
     tokens = []
     weights = []
+    truncated = False
     for text in prompt:
         texts_and_weights = parse_prompt_attention(text)
         text_token = []
```
```diff
@@ -138,21 +141,21 @@ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_len
             # tokenize and discard the starting and the ending token
             token = pipe.tokenizer(word).input_ids[1:-1]
             text_token += token
-
             # copy the weight by length of token
             text_weight += [weight] * len(token)
-
             # stop if the text is too long (longer than truncation limit)
             if len(text_token) > max_length:
+                truncated = True
                 break
-
         # truncate
         if len(text_token) > max_length:
+            truncated = True
             text_token = text_token[:max_length]
             text_weight = text_weight[:max_length]
-
         tokens.append(text_token)
         weights.append(text_weight)
+    if truncated:
+        logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
     return tokens, weights
 
 
```
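The `truncated` flag is the functional core of this hunk: over-long prompts are still cut, but the user now gets a warning instead of silent truncation. A minimal sketch of the condition being guarded (the checkpoint name is illustrative, not part of this commit):

```python
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
max_embeddings_multiples = 3

# each 77-token window holds 75 content tokens; BOS/EOS take the other two slots
max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples  # 225

prompt = "masterpiece, best quality, " * 60
n_tokens = len(tokenizer(prompt).input_ids[1:-1])  # strip BOS/EOS, as the pipeline does
if n_tokens > max_length:
    print(f"{n_tokens} > {max_length}: this prompt would now log the truncation warning")
```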
```diff
@@ -171,9 +174,9 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd
             if len(weights[i]) == 0:
                 w = [1.0] * weights_length
             else:
-                for j in range((len(weights[i]) - 1) // chunk_length + 1):
+                for j in range(max_embeddings_multiples):
                     w.append(1.0)  # weight for starting token in this chunk
-                    w += weights[i][j * chunk_length : min(len(weights[i]), (j + 1) * chunk_length)]
+                    w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
                     w.append(1.0)  # weight for ending token in this chunk
                 w += [1.0] * (weights_length - len(w))
             weights[i] = w[:]
```
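This is the substantive fix: each encoder window reserves two slots for BOS/EOS, so only `chunk_length - 2` of the per-token weights belong to any one chunk, and the loop now runs exactly `max_embeddings_multiples` times. A worked sketch with CLIP's 77-token window (numbers are illustrative):

```python
chunk_length = 77                # full encoder window (tokenizer.model_max_length)
content = chunk_length - 2       # 75 content tokens; BOS/EOS take the other two slots

weights = [0.5] * 100            # per-token weights for a 100-token prompt
max_embeddings_multiples = 2

w = []
for j in range(max_embeddings_multiples):
    w.append(1.0)                # weight for the chunk's BOS token
    w += weights[j * content : min(len(weights), (j + 1) * content)]
    w.append(1.0)                # weight for the chunk's EOS token

print(len(w))  # 104: a full 77-entry chunk, then 1.0 + the remaining 25 weights + 1.0
```

With a stride of the full `chunk_length`, the second chunk's slice would start two entries too late and every weight after the first window would sit on the wrong token.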
```diff
@@ -182,7 +185,10 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd
 
 
 def get_unweighted_text_embeddings(
-    pipe: DiffusionPipeline, text_input: torch.Tensor, chunk_length: int, no_boseos_middle: Optional[bool] = True
+    pipe: DiffusionPipeline,
+    text_input: torch.Tensor,
+    chunk_length: int,
+    no_boseos_middle: Optional[bool] = True,
 ):
     """
     When the length of tokens is a multiple of the capacity of the text encoder,
```
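Formatting only: the parameter list is exploded one per line with a trailing comma; names, annotations, and the `no_boseos_middle` default are unchanged.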
```diff
@@ -283,7 +289,8 @@ def get_weighted_text_embeddings(
         max_length = max(max_length, max([len(token) for token in uncond_tokens]))
 
     max_embeddings_multiples = min(
-        max_embeddings_multiples, (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1
+        max_embeddings_multiples,
+        (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
     )
     max_embeddings_multiples = max(1, max_embeddings_multiples)
     max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
```
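The `min`/`max` pair clamps `max_embeddings_multiples` to the smallest number of 75-content-token chunks that covers the longest prompt in the batch; the expression is integer ceiling division. Checking the arithmetic:

```python
model_max_length = 77
max_length = 150                 # longest tokenized prompt in the batch

# ceil(150 / 75) without floats: (150 - 1) // 75 + 1 == 2
multiples = (max_length - 1) // (model_max_length - 2) + 1
print(multiples)                               # 2
print((model_max_length - 2) * multiples + 2)  # 152: padded length incl. BOS/EOS
```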
```diff
@@ -315,12 +322,18 @@
 
     # get the embeddings
     text_embeddings = get_unweighted_text_embeddings(
-        pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
+        pipe,
+        prompt_tokens,
+        pipe.tokenizer.model_max_length,
+        no_boseos_middle=no_boseos_middle,
     )
     prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
     if uncond_prompt is not None:
         uncond_embeddings = get_unweighted_text_embeddings(
-            pipe, uncond_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
+            pipe,
+            uncond_tokens,
+            pipe.tokenizer.model_max_length,
+            no_boseos_middle=no_boseos_middle,
         )
         uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
 
```
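The embedding calls are reflowed the same way. Worth noting: `chunk_length` here is the full `pipe.tokenizer.model_max_length` window (77 for CLIP), which is exactly why the padding loop above has to step by `chunk_length - 2`.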
```diff
@@ -630,16 +643,29 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
+        latents_shape = (
+            batch_size * num_images_per_prompt,
+            self.unet.in_channels,
+            height // 8,
+            width // 8,
+        )
 
         if latents is None:
             if self.device.type == "mps":
                 # randn does not exist on mps
-                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
-                    self.device
-                )
+                latents = torch.randn(
+                    latents_shape,
+                    generator=generator,
+                    device="cpu",
+                    dtype=latents_dtype,
+                ).to(self.device)
             else:
-                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
+                latents = torch.randn(
+                    latents_shape,
+                    generator=generator,
+                    device=self.device,
+                    dtype=latents_dtype,
+                )
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
```
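Same latents, new layout: the `mps` branch keeps the established workaround of sampling on CPU with the seeded generator and moving the result over, since `torch.randn` was unavailable on the mps backend in the PyTorch versions this targeted. A self-contained sketch of the pattern (the helper name is ours, not the pipeline's):

```python
import torch

def randn_for(shape, generator, device, dtype):
    # Older PyTorch had no randn on mps: sample on CPU, then move to the target device.
    if torch.device(device).type == "mps":
        return torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
    return torch.randn(shape, generator=generator, device=device, dtype=dtype)

gen = torch.Generator().manual_seed(42)
# (batch * images_per_prompt, unet.in_channels, height // 8, width // 8)
latents = randn_for((1, 4, 64, 64), gen, "cpu", torch.float32)
```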
```diff
@@ -682,11 +708,19 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
         # add noise to latents using the timesteps
         if self.device.type == "mps":
             # randn does not exist on mps
-            noise = torch.randn(init_latents.shape, generator=generator, device="cpu", dtype=latents_dtype).to(
-                self.device
-            )
+            noise = torch.randn(
+                init_latents.shape,
+                generator=generator,
+                device="cpu",
+                dtype=latents_dtype,
+            ).to(self.device)
         else:
-            noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
+            noise = torch.randn(
+                init_latents.shape,
+                generator=generator,
+                device=self.device,
+                dtype=latents_dtype,
+            )
         latents = self.scheduler.add_noise(init_latents, noise, timesteps)
 
         t_start = max(num_inference_steps - init_timestep + offset, 0)
```
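The img2img path gets the identical CPU-then-move treatment for its noise tensor; `self.scheduler.add_noise(init_latents, noise, timesteps)` and the `t_start` computation are untouched.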
```diff
@@ -739,7 +773,8 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
                 self.device
             )
             image, has_nsfw_concept = self.safety_checker(
-                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+                images=image,
+                clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype),
             )
         else:
             has_nsfw_concept = None
```
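Formatting only again: the `safety_checker` call gains one-argument-per-line layout with a trailing comma; the `images` and `clip_input` values are unchanged.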