xiaozeng committed
Commit 05654ff • 1 Parent(s): 7594647

Upload with huggingface_hub

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. README.md +6 -6
  2. app.py +1677 -0
  3. env.py +13 -0
  4. ppdiffusers/__init__.py +162 -0
  5. ppdiffusers/__pycache__/__init__.cpython-37.pyc +0 -0
  6. ppdiffusers/__pycache__/configuration_utils.cpython-37.pyc +0 -0
  7. ppdiffusers/__pycache__/download_utils.cpython-37.pyc +0 -0
  8. ppdiffusers/__pycache__/fastdeploy_utils.cpython-37.pyc +0 -0
  9. ppdiffusers/__pycache__/initializer.cpython-37.pyc +0 -0
  10. ppdiffusers/__pycache__/loaders.cpython-37.pyc +0 -0
  11. ppdiffusers/__pycache__/modeling_utils.cpython-37.pyc +0 -0
  12. ppdiffusers/__pycache__/optimization.cpython-37.pyc +0 -0
  13. ppdiffusers/__pycache__/pipeline_utils.cpython-37.pyc +0 -0
  14. ppdiffusers/__pycache__/ppnlp_patch_utils.cpython-37.pyc +0 -0
  15. ppdiffusers/__pycache__/training_utils.cpython-37.pyc +0 -0
  16. ppdiffusers/__pycache__/version.cpython-37.pyc +0 -0
  17. ppdiffusers/commands/__init__.py +28 -0
  18. ppdiffusers/commands/env.py +67 -0
  19. ppdiffusers/commands/ppdiffusers_cli.py +41 -0
  20. ppdiffusers/configuration_utils.py +591 -0
  21. ppdiffusers/download_utils.py +44 -0
  22. ppdiffusers/experimental/README.md +6 -0
  23. ppdiffusers/experimental/__init__.py +17 -0
  24. ppdiffusers/experimental/rl/__init__.py +17 -0
  25. ppdiffusers/experimental/rl/value_guided_sampling.py +146 -0
  26. ppdiffusers/fastdeploy_utils.py +260 -0
  27. ppdiffusers/initializer.py +303 -0
  28. ppdiffusers/loaders.py +190 -0
  29. ppdiffusers/modeling_paddle_pytorch_utils.py +106 -0
  30. ppdiffusers/modeling_utils.py +619 -0
  31. ppdiffusers/models/__init__.py +25 -0
  32. ppdiffusers/models/__pycache__/__init__.cpython-37.pyc +0 -0
  33. ppdiffusers/models/__pycache__/attention.cpython-37.pyc +0 -0
  34. ppdiffusers/models/__pycache__/cross_attention.cpython-37.pyc +0 -0
  35. ppdiffusers/models/__pycache__/embeddings.cpython-37.pyc +0 -0
  36. ppdiffusers/models/__pycache__/prior_transformer.cpython-37.pyc +0 -0
  37. ppdiffusers/models/__pycache__/resnet.cpython-37.pyc +0 -0
  38. ppdiffusers/models/__pycache__/unet_1d.cpython-37.pyc +0 -0
  39. ppdiffusers/models/__pycache__/unet_1d_blocks.cpython-37.pyc +0 -0
  40. ppdiffusers/models/__pycache__/unet_2d.cpython-37.pyc +0 -0
  41. ppdiffusers/models/__pycache__/unet_2d_blocks.cpython-37.pyc +0 -0
  42. ppdiffusers/models/__pycache__/unet_2d_condition.cpython-37.pyc +0 -0
  43. ppdiffusers/models/__pycache__/vae.cpython-37.pyc +0 -0
  44. ppdiffusers/models/attention.py +683 -0
  45. ppdiffusers/models/cross_attention.py +435 -0
  46. ppdiffusers/models/ema.py +103 -0
  47. ppdiffusers/models/embeddings.py +199 -0
  48. ppdiffusers/models/prior_transformer.py +220 -0
  49. ppdiffusers/models/resnet.py +716 -0
  50. ppdiffusers/models/unet_1d.py +247 -0
README.md CHANGED
@@ -1,12 +1,12 @@
  ---
- title: Lora Test
- emoji: 🐨
- colorFrom: gray
- colorTo: green
+ title: LoRa ppdiffusers dreambooth
+ emoji: 🎨🎞️
+ colorFrom: pink
+ colorTo: purple
  sdk: gradio
- sdk_version: 3.19.1
+ sdk_version: 3.18.0
  app_file: app.py
  pinned: false
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,1677 @@
+ # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import gradio as gr
+ from env import BASE_MODEL_NAME, LORA_WEIGHTS_PATH, PROMPTS
+
+ examples = [
+     [
+         PROMPTS,
+         'low quality',
+         7.5,
+         512,
+         512,
+         25,
+         "DPMSolver"
+     ],
+ ]
+ import inspect
+ import os
+ import random
+ import re
+ import time
+ from typing import Callable, List, Optional, Union
+
+ import numpy as np
+ import paddle
+ import PIL
+ import PIL.Image
+ from packaging import version
+
+ from paddlenlp.transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+ from ppdiffusers.configuration_utils import FrozenDict
+ from ppdiffusers.models import AutoencoderKL, UNet2DConditionModel
+ from ppdiffusers.pipeline_utils import DiffusionPipeline
+ from ppdiffusers.schedulers import (
+     DDIMScheduler,
+     DPMSolverMultistepScheduler,
+     EulerAncestralDiscreteScheduler,
+     EulerDiscreteScheduler,
+     LMSDiscreteScheduler,
+     PNDMScheduler,
+     HeunDiscreteScheduler,
+     KDPM2AncestralDiscreteScheduler,
+     KDPM2DiscreteScheduler,
+
+ )
+ from ppdiffusers.utils import PIL_INTERPOLATION, deprecate, logging
+ from ppdiffusers.utils.testing_utils import load_image
+ from ppdiffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+ from ppdiffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ def save_all(images, FORMAT="jpg", OUTDIR="./outputs/"):
+     if not isinstance(images, (list, tuple)):
+         images = [images]
+     for image in images:
+         PRECISION = "fp32"
+         argument = image.argument
+         os.makedirs(OUTDIR, exist_ok=True)
+         epoch_time = argument["epoch_time"]
+         PROMPT = argument["prompt"]
+         NEGPROMPT = argument["negative_prompt"]
+         HEIGHT = argument["height"]
+         WIDTH = argument["width"]
+         SEED = argument["seed"]
+         STRENGTH = argument.get("strength", 1)
+         INFERENCE_STEPS = argument["num_inference_steps"]
+         GUIDANCE_SCALE = argument["guidance_scale"]
+
+         filename = f"{str(epoch_time)}_scale_{GUIDANCE_SCALE}_steps_{INFERENCE_STEPS}_seed_{SEED}.{FORMAT}"
+         filedir = f"{OUTDIR}/{filename}"
+         image.save(filedir)
+         with open(f"{OUTDIR}/{epoch_time}_prompt.txt", "w") as file:
+             file.write(
+                 f"PROMPT: {PROMPT}\nNEG_PROMPT: {NEGPROMPT}\n\nINFERENCE_STEPS: {INFERENCE_STEPS}\nHeight: {HEIGHT}\nWidth: {WIDTH}\nSeed: {SEED}\n\nPrecision: {PRECISION}\nSTRENGTH: {STRENGTH}\nGUIDANCE_SCALE: {GUIDANCE_SCALE}"
+             )
+
+
+ re_attention = re.compile(
+     r"""
+ \\\(|
+ \\\)|
+ \\\[|
+ \\]|
+ \\\\|
+ \\|
+ \(|
+ \[|
+ :([+-]?[.\d]+)\)|
+ \)|
+ ]|
+ [^\\()\[\]:]+|
+ :
+ """,
+     re.X,
+ )
+
+
+ def parse_prompt_attention(text):
+     """
+     Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
+     Accepted tokens are:
+       (abc) - increases attention to abc by a multiplier of 1.1
+       (abc:3.12) - increases attention to abc by a multiplier of 3.12
+       [abc] - decreases attention to abc by a multiplier of 1.1
+       \( - literal character '('
+       \[ - literal character '['
+       \) - literal character ')'
+       \] - literal character ']'
+       \\ - literal character '\'
+       anything else - just text
+     >>> parse_prompt_attention('normal text')
+     [['normal text', 1.0]]
+     >>> parse_prompt_attention('an (important) word')
+     [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
+     >>> parse_prompt_attention('(unbalanced')
+     [['unbalanced', 1.1]]
+     >>> parse_prompt_attention('\(literal\]')
+     [['(literal]', 1.0]]
+     >>> parse_prompt_attention('(unnecessary)(parens)')
+     [['unnecessaryparens', 1.1]]
+     >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
+     [['a ', 1.0],
+      ['house', 1.5730000000000004],
+      [' ', 1.1],
+      ['on', 1.0],
+      [' a ', 1.1],
+      ['hill', 0.55],
+      [', sun, ', 1.1],
+      ['sky', 1.4641000000000006],
+      ['.', 1.1]]
+     """
+
+     res = []
+     round_brackets = []
+     square_brackets = []
+
+     round_bracket_multiplier = 1.1
+     square_bracket_multiplier = 1 / 1.1
+
+     def multiply_range(start_position, multiplier):
+         for p in range(start_position, len(res)):
+             res[p][1] *= multiplier
+
+     for m in re_attention.finditer(text):
+         text = m.group(0)
+         weight = m.group(1)
+
+         if text.startswith("\\"):
+             res.append([text[1:], 1.0])
+         elif text == "(":
+             round_brackets.append(len(res))
+         elif text == "[":
+             square_brackets.append(len(res))
+         elif weight is not None and len(round_brackets) > 0:
+             multiply_range(round_brackets.pop(), float(weight))
+         elif text == ")" and len(round_brackets) > 0:
+             multiply_range(round_brackets.pop(), round_bracket_multiplier)
+         elif text == "]" and len(square_brackets) > 0:
+             multiply_range(square_brackets.pop(), square_bracket_multiplier)
+         else:
+             res.append([text, 1.0])
+
+     for pos in round_brackets:
+         multiply_range(pos, round_bracket_multiplier)
+
+     for pos in square_brackets:
+         multiply_range(pos, square_bracket_multiplier)
+
+     if len(res) == 0:
+         res = [["", 1.0]]
+
+     # merge runs of identical weights
+     i = 0
+     while i + 1 < len(res):
+         if res[i][1] == res[i + 1][1]:
+             res[i][0] += res[i + 1][0]
+             res.pop(i + 1)
+         else:
+             i += 1
+
+     return res
+
+
+ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int):
+     r"""
+     Tokenize a list of prompts and return its tokens with weights of each token.
+
+     No padding, starting or ending token is included.
+     """
+     tokens = []
+     weights = []
+     for text in prompt:
+         texts_and_weights = parse_prompt_attention(text)
+         text_token = []
+         text_weight = []
+         for word, weight in texts_and_weights:
+             # tokenize and discard the starting and the ending token
+             token = pipe.tokenizer(word).input_ids[1:-1]
+             text_token += token
+
+             # copy the weight by length of token
+             text_weight += [weight] * len(token)
+
+             # stop if the text is too long (longer than truncation limit)
+             if len(text_token) > max_length:
+                 break
+
+         # truncate
+         if len(text_token) > max_length:
+             text_token = text_token[:max_length]
+             text_weight = text_weight[:max_length]
+
+         tokens.append(text_token)
+         weights.append(text_weight)
+     return tokens, weights
+
+
+ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77):
+     r"""
+     Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
+     """
+     max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
+     weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
+     for i in range(len(tokens)):
+         tokens[i] = [bos] + tokens[i] + [eos] + [pad] * (max_length - 2 - len(tokens[i]))
+         if no_boseos_middle:
+             weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
+         else:
+             w = []
+             if len(weights[i]) == 0:
+                 w = [1.0] * weights_length
+             else:
+                 for j in range((len(weights[i]) - 1) // chunk_length + 1):
+                     w.append(1.0)  # weight for starting token in this chunk
+                     w += weights[i][j * chunk_length : min(len(weights[i]), (j + 1) * chunk_length)]
+                     w.append(1.0)  # weight for ending token in this chunk
+                 w += [1.0] * (weights_length - len(w))
+             weights[i] = w[:]
+
+     return tokens, weights
+
+
+ def get_unweighted_text_embeddings(
+     pipe: DiffusionPipeline, text_input: paddle.Tensor, chunk_length: int, no_boseos_middle: Optional[bool] = True
+ ):
+     """
+     When the length of tokens is a multiple of the capacity of the text encoder,
+     it should be split into chunks and sent to the text encoder individually.
+     """
+     max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
+     if max_embeddings_multiples > 1:
+         text_embeddings = []
+         for i in range(max_embeddings_multiples):
+             # extract the i-th chunk
+             text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
+
+             # cover the head and the tail by the starting and the ending tokens
+             text_input_chunk[:, 0] = text_input[0, 0]
+             text_input_chunk[:, -1] = text_input[0, -1]
+
+             text_embedding = pipe.text_encoder(text_input_chunk)[0]
+
+             if no_boseos_middle:
+                 if i == 0:
+                     # discard the ending token
+                     text_embedding = text_embedding[:, :-1]
+                 elif i == max_embeddings_multiples - 1:
+                     # discard the starting token
+                     text_embedding = text_embedding[:, 1:]
+                 else:
+                     # discard both starting and ending tokens
+                     text_embedding = text_embedding[:, 1:-1]
+
+             text_embeddings.append(text_embedding)
+         text_embeddings = paddle.concat(text_embeddings, axis=1)
+     else:
+         text_embeddings = pipe.text_encoder(text_input)[0]
+     return text_embeddings
+
+
+ def get_weighted_text_embeddings(
+     pipe: DiffusionPipeline,
+     prompt: Union[str, List[str]],
+     uncond_prompt: Optional[Union[str, List[str]]] = None,
+     max_embeddings_multiples: Optional[int] = 1,
+     no_boseos_middle: Optional[bool] = False,
+     skip_parsing: Optional[bool] = False,
+     skip_weighting: Optional[bool] = False,
+     **kwargs
+ ):
+     r"""
+     Prompts can be assigned with local weights using brackets. For example,
+     prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
+     and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
+
+     Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
+
+     Args:
+         pipe (`DiffusionPipeline`):
+             Pipe to provide access to the tokenizer and the text encoder.
+         prompt (`str` or `List[str]`):
+             The prompt or prompts to guide the image generation.
+         uncond_prompt (`str` or `List[str]`):
+             The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
+             is provided, the embeddings of prompt and uncond_prompt are concatenated.
+         max_embeddings_multiples (`int`, *optional*, defaults to `1`):
+             The max multiple length of prompt embeddings compared to the max output length of the text encoder.
+         no_boseos_middle (`bool`, *optional*, defaults to `False`):
+             When the text token length is a multiple of the text encoder's capacity, whether to keep the starting
+             and ending tokens in each of the middle chunks.
+         skip_parsing (`bool`, *optional*, defaults to `False`):
+             Skip the parsing of brackets.
+         skip_weighting (`bool`, *optional*, defaults to `False`):
+             Skip the weighting. When the parsing is skipped, it is forced to True.
+     """
+     max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
+     if isinstance(prompt, str):
+         prompt = [prompt]
+
+     if not skip_parsing:
+         prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
+         if uncond_prompt is not None:
+             if isinstance(uncond_prompt, str):
+                 uncond_prompt = [uncond_prompt]
+             uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
+     else:
+         prompt_tokens = [
+             token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids
+         ]
+         prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
+         if uncond_prompt is not None:
+             if isinstance(uncond_prompt, str):
+                 uncond_prompt = [uncond_prompt]
+             uncond_tokens = [
+                 token[1:-1]
+                 for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
+             ]
+             uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
+
+     # round up the longest length of tokens to a multiple of (model_max_length - 2)
+     max_length = max([len(token) for token in prompt_tokens])
+     if uncond_prompt is not None:
+         max_length = max(max_length, max([len(token) for token in uncond_tokens]))
+
+     max_embeddings_multiples = min(
+         max_embeddings_multiples, (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1
+     )
+     max_embeddings_multiples = max(1, max_embeddings_multiples)
+     max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
+
+     # pad the length of tokens and weights
+     # support bert tokenizer
+     bos = pipe.tokenizer.bos_token_id if pipe.tokenizer.bos_token_id is not None else pipe.tokenizer.cls_token_id
+     eos = pipe.tokenizer.eos_token_id if pipe.tokenizer.eos_token_id is not None else pipe.tokenizer.sep_token_id
+     pad = pipe.tokenizer.pad_token_id
+     prompt_tokens, prompt_weights = pad_tokens_and_weights(
+         prompt_tokens,
+         prompt_weights,
+         max_length,
+         bos,
+         eos,
+         pad,
+         no_boseos_middle=no_boseos_middle,
+         chunk_length=pipe.tokenizer.model_max_length,
+     )
+     prompt_tokens = paddle.to_tensor(prompt_tokens)
+     if uncond_prompt is not None:
+         uncond_tokens, uncond_weights = pad_tokens_and_weights(
+             uncond_tokens,
+             uncond_weights,
+             max_length,
+             bos,
+             eos,
+             pad,
+             no_boseos_middle=no_boseos_middle,
+             chunk_length=pipe.tokenizer.model_max_length,
+         )
+         uncond_tokens = paddle.to_tensor(uncond_tokens)
+
+     # get the embeddings
+     text_embeddings = get_unweighted_text_embeddings(
+         pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
+     )
+     prompt_weights = paddle.to_tensor(prompt_weights, dtype=text_embeddings.dtype)
+     if uncond_prompt is not None:
+         uncond_embeddings = get_unweighted_text_embeddings(
+             pipe, uncond_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
+         )
+         uncond_weights = paddle.to_tensor(uncond_weights, dtype=uncond_embeddings.dtype)
+
+     # assign weights to the prompts and normalize in the sense of mean
+     # TODO: should we normalize by chunk or in a whole (current implementation)?
+     if (not skip_parsing) and (not skip_weighting):
+         previous_mean = text_embeddings.mean(axis=[-2, -1])
+         text_embeddings *= prompt_weights.unsqueeze(-1)
+         text_embeddings *= previous_mean / text_embeddings.mean(axis=[-2, -1])
+         if uncond_prompt is not None:
+             previous_mean = uncond_embeddings.mean(axis=[-2, -1])
+             uncond_embeddings *= uncond_weights.unsqueeze(-1)
+             uncond_embeddings *= previous_mean / uncond_embeddings.mean(axis=[-2, -1])
+
+     # For classifier free guidance, we need to do two forward passes.
+     # Here we concatenate the unconditional and text embeddings into a single batch
+     # to avoid doing two forward passes
+     if uncond_prompt is not None:
+         text_embeddings = paddle.concat([uncond_embeddings, text_embeddings])
+
+     return text_embeddings
+
+
+ def preprocess_image(image):
+     w, h = image.size
+     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+     image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
+     image = np.array(image).astype(np.float32) / 255.0
+     image = image[None].transpose(0, 3, 1, 2)
+     image = paddle.to_tensor(image)
+     return 2.0 * image - 1.0
+
+
+ def preprocess_mask(mask):
+     mask = mask.convert("L")
+     w, h = mask.size
+     w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+     mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"])
+     mask = np.array(mask).astype(np.float32) / 255.0
+     mask = np.tile(mask, (4, 1, 1))
+     mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
+     mask = 1 - mask  # repaint white, keep black
+     mask = paddle.to_tensor(mask)
+     return mask
+
+
+ class StableDiffusionPipelineAllinOne(DiffusionPipeline):
+     r"""
+     Pipeline for text-to-image, image-to-image and inpainting generation using Stable Diffusion.
+
+     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+     library implements for all the pipelines (such as downloading or saving, running on a particular xxxx, etc.)
+
+     Args:
+         vae ([`AutoencoderKL`]):
+             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+         text_encoder ([`CLIPTextModel`]):
+             Frozen text-encoder. Stable Diffusion uses the text portion of
+             [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+             the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+         tokenizer (`CLIPTokenizer`):
+             Tokenizer of class
+             [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+         unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+         scheduler ([`SchedulerMixin`]):
+             A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+             [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`PNDMScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`]
+             or [`DPMSolverMultistepScheduler`].
+         safety_checker ([`StableDiffusionSafetyChecker`]):
+             Classification module that estimates whether generated images could be considered offensive or harmful.
+             Please, refer to the [model card](https://huggingface.co/junnyu/stable-diffusion-v1-4-paddle) for details.
+         feature_extractor ([`CLIPFeatureExtractor`]):
+             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+     """
+     _optional_components = ["safety_checker", "feature_extractor"]
+
+     def __init__(
+         self,
+         vae: AutoencoderKL,
+         text_encoder: CLIPTextModel,
+         tokenizer: CLIPTokenizer,
+         unet: UNet2DConditionModel,
+         scheduler: Union[
+             DDIMScheduler,
+             PNDMScheduler,
+             LMSDiscreteScheduler,
+             EulerDiscreteScheduler,
+             EulerAncestralDiscreteScheduler,
+             DPMSolverMultistepScheduler,
+         ],
+         safety_checker: StableDiffusionSafetyChecker,
+         feature_extractor: CLIPFeatureExtractor,
+         requires_safety_checker: bool = False,
+     ):
+         if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+             deprecation_message = (
+                 f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+                 f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                 "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+                 " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+                 " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+                 " file"
+             )
+             deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+             new_config = dict(scheduler.config)
+             new_config["steps_offset"] = 1
+             scheduler._internal_dict = FrozenDict(new_config)
+
+         if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+             deprecation_message = (
+                 f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                 " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                 " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+                 " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+                 " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+             )
+             deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+             new_config = dict(scheduler.config)
+             new_config["clip_sample"] = False
+             scheduler._internal_dict = FrozenDict(new_config)
+
+         if safety_checker is None and requires_safety_checker:
+             logger.warning(
+                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+                 " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+                 " results in services or applications open to the public. PaddleNLP team, diffusers team and Hugging Face"
+                 " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+                 " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+                 " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+             )
+         if safety_checker is not None and feature_extractor is None:
+             raise ValueError(
+                 f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+             )
+         is_unet_version_less_0_9_0 = hasattr(unet.config, "_ppdiffusers_version") and version.parse(
+             version.parse(unet.config._ppdiffusers_version).base_version
+         ) < version.parse("0.9.0.dev0")
+         is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+         if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+             deprecation_message = (
+                 "The configuration file of the unet has set the default `sample_size` to smaller than"
+                 " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+                 " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+                 " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+                 " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+                 " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+                 " in the config might lead to incorrect results in future versions. If you have downloaded this"
+                 " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+                 " the `unet/config.json` file"
+             )
+             deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+             new_config = dict(unet.config)
+             new_config["sample_size"] = 64
+             unet._internal_dict = FrozenDict(new_config)
+
+         self.register_modules(
+             vae=vae,
+             text_encoder=text_encoder,
+             tokenizer=tokenizer,
+             unet=unet,
+             scheduler=scheduler,
+             safety_checker=safety_checker,
+             feature_extractor=feature_extractor,
+         )
+         self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+     def create_scheduler(self, name="DPMSolver"):
+         config = self.scheduler.config
+         if name == "DPMSolver":
+             return DPMSolverMultistepScheduler.from_config(
+                 config,
+                 thresholding=False,
+                 algorithm_type="dpmsolver++",
+                 solver_type="midpoint",
+                 lower_order_final=True,
+             )
+         if name == "EulerDiscrete":
+             return EulerDiscreteScheduler.from_config(config)
+         elif name == "EulerAncestralDiscrete":
+             return EulerAncestralDiscreteScheduler.from_config(config)
+         elif name == "PNDM":
+             return PNDMScheduler.from_config(config)
+         elif name == "DDIM":
+             return DDIMScheduler.from_config(config)
+         elif name == "LMSDiscrete":
+             return LMSDiscreteScheduler.from_config(config)
+         elif name == "HeunDiscrete":
+             return HeunDiscreteScheduler.from_config(config)
+         elif name == "KDPM2AncestralDiscrete":
+             return KDPM2AncestralDiscreteScheduler.from_config(config)
+         elif name == "KDPM2Discrete":
+             return KDPM2DiscreteScheduler.from_config(config)
+         else:
+             raise NotImplementedError
+
+     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+         r"""
+         Enable sliced attention computation.
+
+         When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+         in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+         Args:
+             slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+                 When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+                 a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+                 `attention_head_dim` must be a multiple of `slice_size`.
+         """
+         if slice_size == "auto":
+             if isinstance(self.unet.config.attention_head_dim, int):
+                 # half the attention head size is usually a good trade-off between
+                 # speed and memory
+                 slice_size = self.unet.config.attention_head_dim // 2
+             else:
+                 # if `attention_head_dim` is a list, take the smallest head size
+                 slice_size = min(self.unet.config.attention_head_dim)
+         self.unet.set_attention_slice(slice_size)
+
+     def disable_attention_slicing(self):
+         r"""
+         Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+         back to computing attention in one step.
+         """
+         # set slice_size = `None` to disable `attention slicing`
+         self.enable_attention_slicing(None)
+
+     def __call__(self, *args, **kwargs):
+         return self.text2image(*args, **kwargs)
+
+     def text2img(self, *args, **kwargs):
+         return self.text2image(*args, **kwargs)
+
+     def _encode_prompt(
+         self,
+         prompt,
+         negative_prompt,
+         max_embeddings_multiples,
+         no_boseos_middle,
+         skip_parsing,
+         skip_weighting,
+         do_classifier_free_guidance,
+         num_images_per_prompt,
+     ):
+         if do_classifier_free_guidance and negative_prompt is None:
+             negative_prompt = ""
+         text_embeddings = get_weighted_text_embeddings(
+             self, prompt, negative_prompt, max_embeddings_multiples, no_boseos_middle, skip_parsing, skip_weighting
+         )
+
+         bs_embed, seq_len, _ = text_embeddings.shape
+         text_embeddings = text_embeddings.tile([1, num_images_per_prompt, 1])
+         text_embeddings = text_embeddings.reshape([bs_embed * num_images_per_prompt, seq_len, -1])
+         return text_embeddings
+
+     def run_safety_checker(self, image, dtype):
+         if self.safety_checker is not None:
+             safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pd")
+             image, has_nsfw_concept = self.safety_checker(
+                 images=image, clip_input=safety_checker_input.pixel_values.cast(dtype)
+             )
+         else:
+             has_nsfw_concept = None
+         return image, has_nsfw_concept
+
+     def decode_latents(self, latents):
+         latents = 1 / 0.18215 * latents
+         image = self.vae.decode(latents).sample
+         image = (image / 2 + 0.5).clip(0, 1)
+         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+         image = image.transpose([0, 2, 3, 1]).cast("float32").numpy()
+         return image
+
+     def prepare_extra_step_kwargs(self, eta, scheduler):
+         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+         # and should be between [0, 1]
+
+         accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
+         extra_step_kwargs = {}
+         if accepts_eta:
+             extra_step_kwargs["eta"] = eta
+
+         return extra_step_kwargs
+
+     def check_inputs_text2img(self, prompt, height, width, callback_steps):
+         if not isinstance(prompt, str) and not isinstance(prompt, list):
+             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+         if height % 8 != 0 or width % 8 != 0:
+             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+         if (callback_steps is None) or (
+             callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+         ):
+             raise ValueError(
+                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                 f" {type(callback_steps)}."
+             )
+
+     def check_inputs_img2img_inpaint(self, prompt, strength, callback_steps):
+         if not isinstance(prompt, str) and not isinstance(prompt, list):
+             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+         if strength < 0 or strength > 1:
+             raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
+         if (callback_steps is None) or (
+             callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+         ):
+             raise ValueError(
+                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                 f" {type(callback_steps)}."
+             )
+
+     def prepare_latents_text2img(self, batch_size, num_channels_latents, height, width, dtype, latents=None, scheduler=None):
+         shape = [batch_size, num_channels_latents, height // 8, width // 8]
+         if latents is None:
+             latents = paddle.randn(shape, dtype=dtype)
+         else:
+             if latents.shape != shape:
+                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+
+         # scale the initial noise by the standard deviation required by the scheduler
+         latents = latents * scheduler.init_noise_sigma
+         return latents
+
+     def prepare_latents_img2img(self, image, timestep, num_images_per_prompt, dtype, scheduler):
+         image = image.cast(dtype=dtype)
+         init_latent_dist = self.vae.encode(image).latent_dist
+         init_latents = init_latent_dist.sample()
+         init_latents = 0.18215 * init_latents
+
+         b, c, h, w = init_latents.shape
+         init_latents = init_latents.tile([1, num_images_per_prompt, 1, 1])
+         init_latents = init_latents.reshape([b * num_images_per_prompt, c, h, w])
+
+         # add noise to latents using the timesteps
+         noise = paddle.randn(init_latents.shape, dtype=dtype)
+
+         # get latents
+         init_latents = scheduler.add_noise(init_latents, noise, timestep)
+         latents = init_latents
+
+         return latents
+
+     def get_timesteps(self, num_inference_steps, strength, scheduler):
+         # get the original timestep using init_timestep
+         offset = scheduler.config.get("steps_offset", 0)
+         init_timestep = int(num_inference_steps * strength) + offset
+         init_timestep = min(init_timestep, num_inference_steps)
+
+         t_start = max(num_inference_steps - init_timestep + offset, 0)
+         timesteps = scheduler.timesteps[t_start:]
+
+         return timesteps, num_inference_steps - t_start
+
+     def prepare_latents_inpaint(self, image, timestep, num_images_per_prompt, dtype, scheduler):
+         image = image.cast(dtype)
+         init_latent_dist = self.vae.encode(image).latent_dist
+         init_latents = init_latent_dist.sample()
+         init_latents = 0.18215 * init_latents
+
+         b, c, h, w = init_latents.shape
+         init_latents = init_latents.tile([1, num_images_per_prompt, 1, 1])
+         init_latents = init_latents.reshape([b * num_images_per_prompt, c, h, w])
+
+         init_latents_orig = init_latents
+
+         # add noise to latents using the timesteps
+         noise = paddle.randn(init_latents.shape, dtype=dtype)
+         init_latents = scheduler.add_noise(init_latents, noise, timestep)
+         latents = init_latents
+         return latents, init_latents_orig, noise
+
+     @paddle.no_grad()
+     def text2image(
+         self,
+         prompt: Union[str, List[str]],
+         height: int = 512,
+         width: int = 512,
+         num_inference_steps: int = 50,
+         guidance_scale: float = 7.5,
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         num_images_per_prompt: Optional[int] = 1,
+         eta: float = 0.0,
+         seed: Optional[int] = None,
+         latents: Optional[paddle.Tensor] = None,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+         callback: Optional[Callable[[int, int, paddle.Tensor], None]] = None,
+         callback_steps: Optional[int] = 1,
+         # new add
+         max_embeddings_multiples: Optional[int] = 1,
+         no_boseos_middle: Optional[bool] = False,
+         skip_parsing: Optional[bool] = False,
+         skip_weighting: Optional[bool] = False,
+         scheduler=None,
+         **kwargs,
+     ):
+         r"""
+         Function invoked when calling the pipeline for generation.
+
+         Args:
+             prompt (`str` or `List[str]`):
+                 The prompt or prompts to guide the image generation.
+             height (`int`, *optional*, defaults to 512):
+                 The height in pixels of the generated image.
+             width (`int`, *optional*, defaults to 512):
+                 The width in pixels of the generated image.
+             num_inference_steps (`int`, *optional*, defaults to 50):
+                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                 expense of slower inference.
+             guidance_scale (`float`, *optional*, defaults to 7.5):
+                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                 usually at the expense of lower image quality.
+             negative_prompt (`str` or `List[str]`, *optional*):
+                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                 if `guidance_scale` is less than `1`).
+             num_images_per_prompt (`int`, *optional*, defaults to 1):
+                 The number of images to generate per prompt.
+             eta (`float`, *optional*, defaults to 0.0):
+                 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                 [`schedulers.DDIMScheduler`], will be ignored for others.
+             seed (`int`, *optional*):
+                 Random number seed.
+             latents (`paddle.Tensor`, *optional*):
+                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                 tensor will be generated by sampling using the supplied random `seed`.
+             output_type (`str`, *optional*, defaults to `"pil"`):
+                 The output format of the generated image. Choose between
+                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                 plain tuple.
+             callback (`Callable`, *optional*):
+                 A function that will be called every `callback_steps` steps during inference. The function will be
+                 called with the following arguments: `callback(step: int, timestep: int, latents: paddle.Tensor)`.
+             callback_steps (`int`, *optional*, defaults to 1):
+                 The frequency at which the `callback` function will be called. If not specified, the callback will be
+                 called at every step.
+
+         Returns:
+             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+             When returning a tuple, the first element is a list with the generated images, and the second element is a
+             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+             (nsfw) content, according to the `safety_checker`.
+         """
+         if scheduler is None:
+             scheduler = self.scheduler
+         seed = random.randint(0, 2**32) if seed is None else seed
+         argument = dict(
+             prompt=prompt,
+             negative_prompt=negative_prompt,
+             height=height,
+             width=width,
+             num_inference_steps=num_inference_steps,
+             guidance_scale=guidance_scale,
+             num_images_per_prompt=num_images_per_prompt,
+             eta=eta,
+             seed=seed,
+             latents=latents,
+             max_embeddings_multiples=max_embeddings_multiples,
+             no_boseos_middle=no_boseos_middle,
+             skip_parsing=skip_parsing,
+             skip_weighting=skip_weighting,
+             epoch_time=time.time(),
+         )
+         paddle.seed(seed)
+         # 1. Check inputs. Raise error if not correct
+         self.check_inputs_text2img(prompt, height, width, callback_steps)
+
+         # 2. Define call parameters
+         batch_size = 1 if isinstance(prompt, str) else len(prompt)
+         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+         # corresponds to doing no classifier free guidance.
+         do_classifier_free_guidance = guidance_scale > 1.0
+
+         # 3. Encode input prompt
+         text_embeddings = self._encode_prompt(
+             prompt,
+             negative_prompt,
+             max_embeddings_multiples,
+             no_boseos_middle,
+             skip_parsing,
+             skip_weighting,
+             do_classifier_free_guidance,
+             num_images_per_prompt,
+         )
+
+         # 4. Prepare timesteps
+         scheduler.set_timesteps(num_inference_steps)
+         timesteps = scheduler.timesteps
+
+         # 5. Prepare latent variables
+         num_channels_latents = self.unet.in_channels
+         latents = self.prepare_latents_text2img(
+             batch_size * num_images_per_prompt,
+             num_channels_latents,
+             height,
+             width,
+             text_embeddings.dtype,
+             latents,
+             scheduler=scheduler,
+         )
+
+         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+         extra_step_kwargs = self.prepare_extra_step_kwargs(eta, scheduler)
+
+         # 7. Denoising loop
+         num_warmup_steps = len(timesteps) - num_inference_steps * scheduler.order
+         with self.progress_bar(total=num_inference_steps) as progress_bar:
+             for i, t in enumerate(timesteps):
+                 # expand the latents if we are doing classifier free guidance
+                 latent_model_input = paddle.concat([latents] * 2) if do_classifier_free_guidance else latents
+                 latent_model_input = scheduler.scale_model_input(latent_model_input, t)
+
+                 # predict the noise residual
+                 noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
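For orientation, here is a minimal usage sketch of the pipeline defined above, mirroring the Gradio example row at the top of app.py (PROMPTS, "low quality", guidance 7.5, 512x512, 25 steps, "DPMSolver"). It is an illustration, not part of the commit, and it assumes the base weights referenced by BASE_MODEL_NAME in env.py are available and that `from_pretrained` can assemble the components expected by `StableDiffusionPipelineAllinOne`.

# Hypothetical sketch; BASE_MODEL_NAME must point to a compatible Stable Diffusion checkpoint.
from env import BASE_MODEL_NAME, PROMPTS

pipe = StableDiffusionPipelineAllinOne.from_pretrained(BASE_MODEL_NAME, safety_checker=None)
result = pipe.text2image(
    prompt=PROMPTS,
    negative_prompt="low quality",
    guidance_scale=7.5,
    height=512,
    width=512,
    num_inference_steps=25,
    scheduler=pipe.create_scheduler("DPMSolver"),  # same scheduler name as the example row
    seed=42,
)
result.images[0].save("sample.jpg")  # StableDiffusionPipelineOutput.images holds the generated PIL images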
+