ghoskno committed on
Commit a202a34
1 Parent(s): 0eecb3e

add: Color-Canny Controlnet demo

Files changed (8)
  1. .gitattributes +3 -0
  2. README.md +2 -0
  3. app.py +157 -0
  4. asserts/1.png +3 -0
  5. asserts/2.png +3 -0
  6. asserts/3.png +3 -0
  7. lpw.py +389 -0
  8. requirements.txt +7 -0
.gitattributes CHANGED
@@ -32,3 +32,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ asserts/1.png filter=lfs diff=lfs merge=lfs -text
+ asserts/2.png filter=lfs diff=lfs merge=lfs -text
+ asserts/3.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -7,6 +7,8 @@ sdk: gradio
  sdk_version: 3.27.0
  app_file: app.py
  pinned: false
+ tags:
+ - jax-diffusers-event
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,157 @@
import cv2
import gradio as gr
import numpy as np
import torch

from diffusers import StableDiffusionControlNetPipeline, StableDiffusionLatentUpscalePipeline, ControlNetModel, AutoencoderKL
from diffusers import UniPCMultistepScheduler
from PIL import Image

from lpw import _encode_prompt

controlnet_ColorCanny = ControlNetModel.from_pretrained("ghoskno/Color-Canny-Controlnet-model", torch_dtype=torch.float16)

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)

pipe = StableDiffusionControlNetPipeline.from_pretrained("Lykon/DreamShaper", vae=vae, controlnet=controlnet_ColorCanny, torch_dtype=torch.float16)

pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()
pipe.enable_xformers_memory_efficient_attention()
pipe.enable_attention_slicing()

# Generator seed
generator = torch.manual_seed(0)


def HWC3(x):
    assert x.dtype == np.uint8
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    H, W, C = x.shape
    assert C == 1 or C == 3 or C == 4
    if C == 3:
        return x
    if C == 1:
        return np.concatenate([x, x, x], axis=2)
    if C == 4:
        color = x[:, :, 0:3].astype(np.float32)
        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
        y = color * alpha + 255.0 * (1.0 - alpha)
        y = y.clip(0, 255).astype(np.uint8)
        return y


def resize_image(input_image, resolution, max_edge=False, edge_limit=False):
    H, W, C = input_image.shape

    H = float(H)
    W = float(W)
    if max_edge:
        k = float(resolution) / max(H, W)
    else:
        k = float(resolution) / min(H, W)
    H *= k
    W *= k

    H, W = int(H), int(W)

    img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
    if not edge_limit:
        return img
    # pad up to the next multiple of 64 (ceil, not round, so the canvas is never smaller than the resized image)
    pH = int(np.ceil(H / 64.0)) * 64
    pW = int(np.ceil(W / 64.0)) * 64
    pimg = np.zeros((pH, pW, 3), dtype=img.dtype)

    oH, oW = (pH - H) // 2, (pW - W) // 2
    pimg[oH:oH + H, oW:oW + W] = img
    return pimg


def get_canny_filter(image, format='pil', low_threshold=100, high_threshold=200):
    if not isinstance(image, np.ndarray):
        image = np.array(image)

    image = cv2.Canny(image, low_threshold, high_threshold)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    if format == 'pil':
        image = Image.fromarray(image)
    return image


def get_color_filter(cond_image, mask_size=64):
    H, W = cond_image.shape[:2]
    cond_image = cv2.resize(cond_image, (W // mask_size, H // mask_size), interpolation=cv2.INTER_CUBIC)
    color = cv2.resize(cond_image, (W, H), interpolation=cv2.INTER_NEAREST)
    return color


def get_colorcanny(image, mask_size):
    if not isinstance(image, np.ndarray):
        image = np.array(image)

    canny_img = get_canny_filter(image, format='np')

    color_img = get_color_filter(image, int(mask_size))

    color_img[np.where(canny_img > 128)] = 255
    color_img = Image.fromarray(color_img)
    return color_img


def process(input_image, prompt, n_prompt, strength=1.0, color_mask_size=96, size=512, scale=6.0, ddim_steps=20):
    prompt_embeds, negative_prompt_embeds = _encode_prompt(pipe, prompt, pipe.device, 1, True, n_prompt, 3)
    input_image = resize_image(input_image, size, max_edge=True, edge_limit=True)

    cond_img = get_colorcanny(input_image, color_mask_size)
    output = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        image=cond_img,
        generator=generator,
        num_images_per_prompt=1,
        num_inference_steps=ddim_steps,
        guidance_scale=scale,
        controlnet_conditioning_scale=float(strength)
    )
    return [output.images[0], cond_img]


block = gr.Blocks().queue()

with block:
    gr.Markdown("""
    # Color-Canny-Controlnet

    This is a demo of a ControlNet conditioned on Color & Canny
    """)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(source='upload', type="numpy")
            prompt = gr.Textbox(label="Prompt", value='')
            n_prompt = gr.Textbox(label="Negative Prompt", value='')
            run_button = gr.Button(label="Run")
            with gr.Accordion('Advanced', open=False):
                strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
                color_mask_size = gr.Slider(label="Color Mask Size", minimum=32, maximum=256, value=96, step=16)
                size = gr.Slider(label="Size", minimum=256, maximum=768, value=512, step=128)
                scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=6.0, step=0.1)
                ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=20, step=1)

        with gr.Column():
            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
    ips = [input_image, prompt, n_prompt, strength, color_mask_size, size, scale, ddim_steps]
    run_button.click(fn=process, inputs=ips, outputs=[result_gallery])

    gr.Examples(
        examples=[
            ["./asserts/1.png", "a concept art by Makoto Shinkai, a girl is standing in the middle of the sea", "text, bad anatomy, blurry, (low quality, blurry)"],
            ["./asserts/2.png", "a concept illustration with saturated vivid watercolors by Erin Hanson and Moebius stylized graphic scene", "text, bad anatomy, blurry, (low quality, blurry)"],
            ["./asserts/3.png", "sky city on the sea, with waves churning and wind power plants on the island", "text, bad anatomy, blurry, (low quality, blurry)"],
        ],
        inputs=[
            input_image, prompt, n_prompt
        ],
        outputs=result_gallery,
        fn=process,
        cache_examples=True,
    )
block.launch(debug=True, server_name='0.0.0.0')
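
Below is a minimal, hypothetical usage sketch (not part of the committed file) showing what the Color-Canny conditioning looks like outside the Gradio UI, assuming the helpers above are in scope (importing app.py as written would also launch the Blocks app); input.png, the output paths and the prompts are placeholders, not files in this repo.

# Hypothetical sketch -- paths and prompts are placeholders.
import numpy as np
from PIL import Image

img = np.array(Image.open("input.png").convert("RGB"))

# Preview the conditioning image the demo feeds to the ControlNet: a coarse
# color mosaic (cell size set by Color Mask Size) with white Canny edges on top.
cond = get_colorcanny(resize_image(img, 512, max_edge=True, edge_limit=True), 96)
cond.save("condition.png")

# Full generation, the same path the Run button takes.
result, cond_img = process(img, "a watercolor landscape", "text, bad anatomy, blurry")
result.save("result.png")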
asserts/1.png ADDED

Git LFS Details

  • SHA256: e250800446f8dac17441de781b34cdb54fe1dbf783b25001d7c8751ccb21d766
  • Pointer size: 132 Bytes
  • Size of remote file: 1.3 MB
asserts/2.png ADDED

Git LFS Details

  • SHA256: 4494d2cfd3f452398366a1a74a674994bd56d398b11746a762b817c80aeedc97
  • Pointer size: 132 Bytes
  • Size of remote file: 2.16 MB
asserts/3.png ADDED

Git LFS Details

  • SHA256: 41cea92f93341ce3a0dae5a9f5ebeb063ddf5284062d243297d4183d2f50fdef
  • Pointer size: 132 Bytes
  • Size of remote file: 1.59 MB
lpw.py ADDED
@@ -0,0 +1,389 @@
import re
from typing import List, Optional, Union

import torch

from diffusers import StableDiffusionPipeline
from diffusers.utils import logging

logger = logging.get_logger(__name__)


re_attention = re.compile(
    r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
    re.X,
)


def parse_prompt_attention(text):
    """
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text
    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """

    res = []
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        text = m.group(0)
        weight = m.group(1)

        if text.startswith("\\"):
            res.append([text[1:], 1.0])
        elif text == "(":
            round_brackets.append(len(res))
        elif text == "[":
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), float(weight))
        elif text == ")" and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif text == "]" and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            res.append([text, 1.0])

    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res


def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
    r"""
    Tokenize a list of prompts and return their tokens together with the weight of each token.

    No padding, starting or ending token is included.
    """
    tokens = []
    weights = []
    truncated = False
    for text in prompt:
        texts_and_weights = parse_prompt_attention(text)
        text_token = []
        text_weight = []
        for word, weight in texts_and_weights:
            # tokenize and discard the starting and the ending token
            token = pipe.tokenizer(word).input_ids[1:-1]
            text_token += token
            # copy the weight by length of token
            text_weight += [weight] * len(token)
            # stop if the text is too long (longer than truncation limit)
            if len(text_token) > max_length:
                truncated = True
                break
        # truncate
        if len(text_token) > max_length:
            truncated = True
            text_token = text_token[:max_length]
            text_weight = text_weight[:max_length]
        tokens.append(text_token)
        weights.append(text_weight)
    if truncated:
        logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
    return tokens, weights


def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
    r"""
    Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
    """
    max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
    weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
    for i in range(len(tokens)):
        tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
        if no_boseos_middle:
            weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
        else:
            w = []
            if len(weights[i]) == 0:
                w = [1.0] * weights_length
            else:
                for j in range(max_embeddings_multiples):
                    w.append(1.0)  # weight for starting token in this chunk
                    w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
                    w.append(1.0)  # weight for ending token in this chunk
                w += [1.0] * (weights_length - len(w))
            weights[i] = w[:]

    return tokens, weights


def get_unweighted_text_embeddings(
    pipe: StableDiffusionPipeline,
    text_input: torch.Tensor,
    chunk_length: int,
    no_boseos_middle: Optional[bool] = True,
):
    """
    When the length of tokens is a multiple of the capacity of the text encoder,
    it should be split into chunks and sent to the text encoder individually.
    """
    max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
    if max_embeddings_multiples > 1:
        text_embeddings = []
        for i in range(max_embeddings_multiples):
            # extract the i-th chunk
            text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()

            # cover the head and the tail by the starting and the ending tokens
            text_input_chunk[:, 0] = text_input[0, 0]
            text_input_chunk[:, -1] = text_input[0, -1]
            text_embedding = pipe.text_encoder(text_input_chunk)[0]

            if no_boseos_middle:
                if i == 0:
                    # discard the ending token
                    text_embedding = text_embedding[:, :-1]
                elif i == max_embeddings_multiples - 1:
                    # discard the starting token
                    text_embedding = text_embedding[:, 1:]
                else:
                    # discard both starting and ending tokens
                    text_embedding = text_embedding[:, 1:-1]

            text_embeddings.append(text_embedding)
        text_embeddings = torch.concat(text_embeddings, axis=1)
    else:
        text_embeddings = pipe.text_encoder(text_input)[0]
    return text_embeddings


def get_weighted_text_embeddings(
    pipe: StableDiffusionPipeline,
    prompt: Union[str, List[str]],
    uncond_prompt: Optional[Union[str, List[str]]] = None,
    max_embeddings_multiples: Optional[int] = 3,
    no_boseos_middle: Optional[bool] = False,
    skip_parsing: Optional[bool] = False,
    skip_weighting: Optional[bool] = False,
    **kwargs,
):
    r"""
    Prompts can be assigned with local weights using brackets. For example,
    prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
    and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.

    Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.

    Args:
        pipe (`StableDiffusionPipeline`):
            Pipe to provide access to the tokenizer and the text encoder.
        prompt (`str` or `List[str]`):
            The prompt or prompts to guide the image generation.
        uncond_prompt (`str` or `List[str]`):
            The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
            is provided, the embeddings of prompt and uncond_prompt are concatenated.
        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
            The max multiple length of prompt embeddings compared to the max output length of text encoder.
        no_boseos_middle (`bool`, *optional*, defaults to `False`):
            If the length of the text tokens is a multiple of the capacity of the text encoder, whether to keep
            the starting and ending tokens in each of the chunks in the middle.
        skip_parsing (`bool`, *optional*, defaults to `False`):
            Skip the parsing of brackets.
        skip_weighting (`bool`, *optional*, defaults to `False`):
            Skip the weighting. When the parsing is skipped, it is forced True.
    """
    max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
    if isinstance(prompt, str):
        prompt = [prompt]

    if not skip_parsing:
        prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
        if uncond_prompt is not None:
            if isinstance(uncond_prompt, str):
                uncond_prompt = [uncond_prompt]
            uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
    else:
        prompt_tokens = [
            token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids
        ]
        prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
        if uncond_prompt is not None:
            if isinstance(uncond_prompt, str):
                uncond_prompt = [uncond_prompt]
            uncond_tokens = [
                token[1:-1]
                for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
            ]
            uncond_weights = [[1.0] * len(token) for token in uncond_tokens]

    # round up the longest length of tokens to a multiple of (model_max_length - 2)
    max_length = max([len(token) for token in prompt_tokens])
    if uncond_prompt is not None:
        max_length = max(max_length, max([len(token) for token in uncond_tokens]))

    max_embeddings_multiples = min(
        max_embeddings_multiples,
        (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
    )
    max_embeddings_multiples = max(1, max_embeddings_multiples)
    max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2

    # pad the length of tokens and weights
    bos = pipe.tokenizer.bos_token_id
    eos = pipe.tokenizer.eos_token_id
    prompt_tokens, prompt_weights = pad_tokens_and_weights(
        prompt_tokens,
        prompt_weights,
        max_length,
        bos,
        eos,
        no_boseos_middle=no_boseos_middle,
        chunk_length=pipe.tokenizer.model_max_length,
    )
    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.text_encoder.device)
    if uncond_prompt is not None:
        uncond_tokens, uncond_weights = pad_tokens_and_weights(
            uncond_tokens,
            uncond_weights,
            max_length,
            bos,
            eos,
            no_boseos_middle=no_boseos_middle,
            chunk_length=pipe.tokenizer.model_max_length,
        )
        uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.text_encoder.device)

    # get the embeddings
    text_embeddings = get_unweighted_text_embeddings(
        pipe,
        prompt_tokens,
        pipe.tokenizer.model_max_length,
        no_boseos_middle=no_boseos_middle,
    )
    prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.text_encoder.device)
    if uncond_prompt is not None:
        uncond_embeddings = get_unweighted_text_embeddings(
            pipe,
            uncond_tokens,
            pipe.tokenizer.model_max_length,
            no_boseos_middle=no_boseos_middle,
        )
        uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.text_encoder.device)

    # assign weights to the prompts and normalize in the sense of mean
    # TODO: should we normalize by chunk or in a whole (current implementation)?
    if (not skip_parsing) and (not skip_weighting):
        previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
        text_embeddings *= prompt_weights.unsqueeze(-1)
        current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
        text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
        if uncond_prompt is not None:
            previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
            uncond_embeddings *= uncond_weights.unsqueeze(-1)
            current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
            uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)

    if uncond_prompt is not None:
        return text_embeddings, uncond_embeddings
    return text_embeddings, None


def _encode_prompt(
    pipe,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt,
    max_embeddings_multiples,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `list(int)`):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`):
            The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
            if `guidance_scale` is less than `1`).
        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
            The max multiple length of prompt embeddings compared to the max output length of text encoder.
    """
    batch_size = len(prompt) if isinstance(prompt, list) else 1

    if negative_prompt is None:
        negative_prompt = [""] * batch_size
    elif isinstance(negative_prompt, str):
        negative_prompt = [negative_prompt] * batch_size
    if batch_size != len(negative_prompt):
        raise ValueError(
            f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
            f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
            " the batch size of `prompt`."
        )

    text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
        pipe=pipe,
        prompt=prompt,
        uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
        max_embeddings_multiples=max_embeddings_multiples,
    )
    return text_embeddings, uncond_embeddings
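
A small illustrative example (not part of the committed file) of what the attention parser produces and how app.py consumes it; the weights follow the rules documented in the parse_prompt_attention docstring above.

# Hypothetical example: weight syntax in a prompt.
print(parse_prompt_attention("a (very beautiful:1.3) [dark] forest"))
# -> [['a ', 1.0], ['very beautiful', 1.3], [' ', 1.0], ['dark', 0.9090909090909091], [' forest', 1.0]]

# app.py turns the weighted prompts into embeddings and passes them straight to the ControlNet pipeline:
# prompt_embeds, negative_prompt_embeds = _encode_prompt(pipe, prompt, pipe.device, 1, True, n_prompt, 3)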
requirements.txt ADDED
@@ -0,0 +1,7 @@
accelerate
diffusers
transformers
torch
xformers
safetensors
opencv-contrib-python
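
To reproduce the Space locally (the usual Spaces workflow, not spelled out in this commit): install the dependencies with pip install -r requirements.txt, then start the demo with python app.py; block.launch binds to 0.0.0.0, so the UI is reachable from other machines on the network. Note that xformers and enable_xformers_memory_efficient_attention() assume a CUDA GPU is available.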