Maitreya Patel committed on
Commit
0c83406
1 Parent(s): df1b27d

initial setup

app.py ADDED
@@ -0,0 +1,222 @@
+ from __future__ import annotations
+ import pathlib
+ import gradio as gr
+ import torch
+ import os
+ import PIL
+ import torchvision.transforms as T
+ import torch.nn.functional as F
+ import numpy as np
+ import cv2
+ import matplotlib.pyplot as plt
+ from typing import Any
+
+ from transformers import (
+     CLIPTextModelWithProjection,
+     CLIPVisionModelWithProjection,
+     CLIPImageProcessor,
+     CLIPTokenizer,
+ )
+
+ from src.priors.lambda_prior_transformer import (
+     PriorTransformer,
+ )  # original Hugging Face prior transformer, without time conditioning
+ from src.pipelines.pipeline_kandinsky_subject_prior import KandinskyPriorPipeline
+
+ from diffusers import DiffusionPipeline
+ from PIL import Image
+
+
+ class Model:
+     def __init__(self):
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+         self.text_encoder = (
+             CLIPTextModelWithProjection.from_pretrained(
+                 "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
+                 projection_dim=1280,
+                 torch_dtype=torch.float16,
+             )
+             .eval()
+             .requires_grad_(False)
+         ).to(self.device)
+
+         self.tokenizer = CLIPTokenizer.from_pretrained(
+             "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
+         )
+
+         prior = PriorTransformer.from_pretrained(
+             "ECLIPSE-Community/Lambda-ECLIPSE-Prior-v1.0",
+             torch_dtype=torch.float16,
+         )
+
+         self.pipe_prior = KandinskyPriorPipeline.from_pretrained(
+             "kandinsky-community/kandinsky-2-2-prior",
+             prior=prior,
+             torch_dtype=torch.float16,
+         ).to(self.device)
+
+         self.pipe = DiffusionPipeline.from_pretrained(
+             "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
+         ).to(self.device)
+
+     def inference(self, raw_data):
+         # Stage 1: predict the image embedding with the Lambda-ECLIPSE prior.
+         # Stage 2: decode it with the Kandinsky 2.2 decoder.
+         image_emb, negative_image_emb = self.pipe_prior(
+             raw_data=raw_data,
+         ).to_tuple()
+         image = self.pipe(
+             image_embeds=image_emb,
+             negative_image_embeds=negative_image_emb,
+             num_inference_steps=50,
+             guidance_scale=4.0,
+         ).images[0]
+         return image
+
+     def process_data(
+         self,
+         image: PIL.Image.Image,
+         keyword: str,
+         image2: PIL.Image.Image,
+         keyword2: str,
+         text: str,
+     ) -> dict[str, Any]:
+         # NOTE: currently unused by `run`; it assumes `self.image_processor` and
+         # `self.vision_encoder` (a CLIP image processor and vision encoder) are attached to the model.
+         print(f"keyword: {keyword}, keyword2: {keyword2}, prompt: {text}")
+         device = torch.device(self.device)
+         data: dict[str, Any] = {}
+         data["text"] = text
+
+         txt = self.tokenizer(
+             text,
+             padding="max_length",
+             truncation=True,
+             return_tensors="pt",
+         )
+         txt_items = {k: v.to(device) for k, v in txt.items()}
+         new_feats = self.text_encoder(**txt_items)
+         new_last_hidden_states = new_feats.last_hidden_state[0].cpu().numpy()
+
+         plt.imshow(image)
+         plt.title("image")
+         plt.savefig("image_test2.png")
+         plt.show()
+
+         mask_img = self.image_processor(image, return_tensors="pt").to(device)
+         vision_feats = self.vision_encoder(**mask_img).image_embeds
+
+         # Replace the keyword's token embeddings with the subject's CLIP image embedding.
+         entity_tokens = self.tokenizer(keyword)["input_ids"][1:-1]
+         for tid in entity_tokens:
+             indices = np.where(txt_items["input_ids"][0].cpu().numpy() == tid)[0]
+             new_last_hidden_states[indices] = vision_feats[0].cpu().numpy()
+             print(indices)
+
+         if image2 is not None:
+             mask_img2 = self.image_processor(image2, return_tensors="pt").to(device)
+             vision_feats2 = self.vision_encoder(**mask_img2).image_embeds
+             if keyword2 is not None:
+                 entity_tokens = self.tokenizer(keyword2)["input_ids"][1:-1]
+                 for tid in entity_tokens:
+                     indices = np.where(txt_items["input_ids"][0].cpu().numpy() == tid)[0]
+                     new_last_hidden_states[indices] = vision_feats2[0].cpu().numpy()
+                     print(indices)
+
+         text_feats = {
+             "prompt_embeds": new_feats.text_embeds.to(device),
+             "text_encoder_hidden_states": torch.tensor(new_last_hidden_states).unsqueeze(0).to(device),
+             "text_mask": txt_items["attention_mask"].to(device),
+         }
+         return text_feats
+
+     def run(
+         self,
+         image: dict[str, PIL.Image.Image],
+         keyword: str,
+         image2: dict[str, PIL.Image.Image],
+         keyword2: str,
+         text: str,
+     ):
+         # aug_feats = self.process_data(image["composite"], keyword, image2["composite"], keyword2, text)
+         sub_imgs = [image["composite"]]
+         if image2:
+             sub_imgs.append(image2["composite"])
+         sub_keywords = [keyword]
+         if keyword2:
+             sub_keywords.append(keyword2)
+         raw_data = {
+             "prompt": text,
+             "subject_images": sub_imgs,
+             "subject_keywords": sub_keywords,
+         }
+         image = self.inference(raw_data)
+         return image
+
+
+ def create_demo():
+     TITLE = '# [λ-ECLIPSE Demo](https://eclipse-t2i.github.io/Lambda-ECLIPSE/)'
+
+     USAGE = '''To run the demo, you should:
+ 1. Upload your image.
+ 2. <span style='color: red;'>**Upload a masked subject image with white blank space, or whiten out the background manually using the brush tool.**</span>
+ 3. Input a keyword, e.g. "Dog".
+ 4. For multi-subject personalization:
+     4-1. Upload another image.
+     4-2. Input its keyword, e.g. "Sunglasses".
+ 5. Input a proper text prompt, such as "A photo of a dog" or "A dog wearing sunglasses". Please use the same keywords in the prompt.
+ 6. Click the Run button.
+ '''
+
+     model = Model()
+
+     with gr.Blocks() as demo:
+         gr.Markdown(TITLE)
+         gr.Markdown(USAGE)
+         with gr.Row():
+             with gr.Column():
+                 with gr.Group():
+                     gr.Markdown(
+                         'Upload your first masked subject image, or mask out the marginal space')
+                     image = gr.ImageEditor(label='Input', type='pil', brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"))
+                     keyword = gr.Text(
+                         label='Keyword',
+                         placeholder='e.g. "Dog", "Goofie"',
+                         info='Keyword for the first subject')
+                     gr.Markdown(
+                         'For multi-subject generation: upload your second masked subject image, or mask out the marginal space')
+                     image2 = gr.ImageEditor(label='Input', type='pil', brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"))
+                     keyword2 = gr.Text(
+                         label='Keyword',
+                         placeholder='e.g. "Sunglasses", "Grand Canyon"',
+                         info='Keyword for the second subject')
+                     prompt = gr.Text(
+                         label='Prompt',
+                         placeholder='e.g. "A photo of dog", "A dog wearing sunglasses"',
+                         info='Keep the keywords used previously in the prompt')
+
+                 run_button = gr.Button('Run')
+
+             with gr.Column():
+                 result = gr.Image(label='Result')
+
+         inputs = [
+             image,
+             keyword,
+             image2,
+             keyword2,
+             prompt,
+         ]
+
+         gr.Examples(
+             examples=[[
+                 os.path.join(os.path.dirname(__file__), "./assets/cat.png"),
+                 "cat",
+                 os.path.join(os.path.dirname(__file__), "./assets/blue_sunglasses.png"),
+                 "glasses",
+                 "A cat wearing glasses on a snowy field",
+             ]],
+             inputs=inputs,
+             fn=model.run,
+             outputs=result,
+         )
+
+         run_button.click(fn=model.run, inputs=inputs, outputs=result)
+     return demo
+
+
+ if __name__ == '__main__':
+     demo = create_demo()
+     demo.queue(api_open=False).launch(share=True)
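
For reference, a minimal sketch (not part of the commit) of driving the same two-stage generation without Gradio, assuming a CUDA device and the example assets used by the demo above; the output filename is arbitrary:

```py
import torch
from diffusers import DiffusionPipeline
from src.priors.lambda_prior_transformer import PriorTransformer
from src.pipelines.pipeline_kandinsky_subject_prior import KandinskyPriorPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

# Lambda-ECLIPSE prior wrapped in the Kandinsky 2.2 prior pipeline, as in app.py.
prior = PriorTransformer.from_pretrained(
    "ECLIPSE-Community/Lambda-ECLIPSE-Prior-v1.0", torch_dtype=torch.float16
)
pipe_prior = KandinskyPriorPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-prior", prior=prior, torch_dtype=torch.float16
).to(device)
pipe = DiffusionPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
).to(device)

# raw_data mirrors what Model.run() builds from the UI inputs.
raw_data = {
    "prompt": "A cat wearing glasses on a snowy field",
    "subject_images": ["./assets/cat.png", "./assets/blue_sunglasses.png"],
    "subject_keywords": ["cat", "glasses"],
}

# Stage 1: subject-aware prior predicts the image embedding.
image_emb, negative_image_emb = pipe_prior(raw_data=raw_data).to_tuple()

# Stage 2: the Kandinsky 2.2 decoder renders the final image.
image = pipe(
    image_embeds=image_emb,
    negative_image_embeds=negative_image_emb,
    num_inference_steps=50,
    guidance_scale=4.0,
).images[0]
image.save("result.png")
```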
assets/a_cat_on_top_of_the_snow_mountain.png ADDED
assets/a_cat_wearing_glasses_at_a_park.png ADDED
assets/overview_white.png ADDED
assets/results.png ADDED
requirements.txt ADDED
@@ -0,0 +1,16 @@
+ accelerate
+ datasets
+ diffusers==0.24.0
+ numpy==1.26.1
+ packaging==23.2
+ pandas_stubs==1.2.0.57
+ Pillow==10.1.0
+ torch==2.0.0
+ torchvision==0.15.1
+ tqdm==4.66.1
+ transformers
+ gradio
+ jmespath
+ opencv-python
+ PyWavelets
src/pipelines/__init__.py ADDED
File without changes
src/pipelines/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (193 Bytes).
 
src/pipelines/__pycache__/pipeline_kandinsky_subject_prior.cpython-39.pyc ADDED
Binary file (18.1 kB).
 
src/pipelines/pipeline_kandinsky_prior.py ADDED
@@ -0,0 +1,591 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional, Union
3
+
4
+ import numpy as np
5
+ import PIL
6
+ import torch
7
+ from transformers import (
8
+ CLIPImageProcessor,
9
+ CLIPTextModelWithProjection,
10
+ CLIPTokenizer,
11
+ CLIPVisionModelWithProjection,
12
+ )
13
+
14
+ from diffusers.models import PriorTransformer
15
+ from diffusers.schedulers import UnCLIPScheduler
16
+ from diffusers.utils import (
17
+ BaseOutput,
18
+ is_accelerate_available,
19
+ is_accelerate_version,
20
+ logging,
21
+ randn_tensor,
22
+ replace_example_docstring,
23
+ )
24
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
25
+
26
+
27
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
+
29
+ EXAMPLE_DOC_STRING = """
30
+ Examples:
31
+ ```py
32
+ >>> from diffusers import KandinskyPipeline, KandinskyPriorPipeline
33
+ >>> import torch
34
+
35
+ >>> pipe_prior = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior")
36
+ >>> pipe_prior.to("cuda")
37
+
38
+ >>> prompt = "red cat, 4k photo"
39
+ >>> out = pipe_prior(prompt)
40
+ >>> image_emb = out.image_embeds
41
+ >>> negative_image_emb = out.negative_image_embeds
42
+
43
+ >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1")
44
+ >>> pipe.to("cuda")
45
+
46
+ >>> image = pipe(
47
+ ... prompt,
48
+ ... image_embeds=image_emb,
49
+ ... negative_image_embeds=negative_image_emb,
50
+ ... height=768,
51
+ ... width=768,
52
+ ... num_inference_steps=100,
53
+ ... ).images
54
+
55
+ >>> image[0].save("cat.png")
56
+ ```
57
+ """
58
+
59
+ EXAMPLE_INTERPOLATE_DOC_STRING = """
60
+ Examples:
61
+ ```py
62
+ >>> from diffusers import KandinskyPriorPipeline, KandinskyPipeline
63
+ >>> from diffusers.utils import load_image
64
+ >>> import PIL
65
+
66
+ >>> import torch
67
+ >>> from torchvision import transforms
68
+
69
+ >>> pipe_prior = KandinskyPriorPipeline.from_pretrained(
70
+ ... "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
71
+ ... )
72
+ >>> pipe_prior.to("cuda")
73
+
74
+ >>> img1 = load_image(
75
+ ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
76
+ ... "/kandinsky/cat.png"
77
+ ... )
78
+
79
+ >>> img2 = load_image(
80
+ ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
81
+ ... "/kandinsky/starry_night.jpeg"
82
+ ... )
83
+
84
+ >>> images_texts = ["a cat", img1, img2]
85
+ >>> weights = [0.3, 0.3, 0.4]
86
+ >>> image_emb, zero_image_emb = pipe_prior.interpolate(images_texts, weights)
87
+
88
+ >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
89
+ >>> pipe.to("cuda")
90
+
91
+ >>> image = pipe(
92
+ ... "",
93
+ ... image_embeds=image_emb,
94
+ ... negative_image_embeds=zero_image_emb,
95
+ ... height=768,
96
+ ... width=768,
97
+ ... num_inference_steps=150,
98
+ ... ).images[0]
99
+
100
+ >>> image.save("starry_cat.png")
101
+ ```
102
+ """
103
+
104
+
105
+ @dataclass
106
+ class KandinskyPriorPipelineOutput(BaseOutput):
107
+ """
108
+ Output class for KandinskyPriorPipeline.
109
+
110
+ Args:
111
+ image_embeds (`torch.FloatTensor`)
112
+ clip image embeddings for text prompt
113
+ negative_image_embeds (`List[PIL.Image.Image]` or `np.ndarray`)
114
+ clip image embeddings for unconditional tokens
115
+ """
116
+
117
+ image_embeds: Union[torch.FloatTensor, np.ndarray]
118
+ negative_image_embeds: Union[torch.FloatTensor, np.ndarray]
119
+
120
+
121
+ class KandinskyPriorPipeline(DiffusionPipeline):
122
+ """
123
+ Pipeline for generating image prior for Kandinsky
124
+
125
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
126
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
127
+
128
+ Args:
129
+ prior ([`PriorTransformer`]):
130
+ The canonical unCLIP prior to approximate the image embedding from the text embedding.
131
+ image_encoder ([`CLIPVisionModelWithProjection`]):
132
+ Frozen image-encoder.
133
+ text_encoder ([`CLIPTextModelWithProjection`]):
134
+ Frozen text-encoder.
135
+ tokenizer (`CLIPTokenizer`):
136
+ Tokenizer of class
137
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
138
+ scheduler ([`UnCLIPScheduler`]):
139
+ A scheduler to be used in combination with `prior` to generate image embedding.
140
+ """
141
+
142
+ _exclude_from_cpu_offload = ["prior"]
143
+
144
+ def __init__(
145
+ self,
146
+ prior: PriorTransformer,
147
+ image_encoder: CLIPVisionModelWithProjection,
148
+ text_encoder: CLIPTextModelWithProjection,
149
+ tokenizer: CLIPTokenizer,
150
+ scheduler: UnCLIPScheduler,
151
+ image_processor: CLIPImageProcessor,
152
+ ):
153
+ super().__init__()
154
+
155
+ self.register_modules(
156
+ prior=prior,
157
+ text_encoder=text_encoder,
158
+ tokenizer=tokenizer,
159
+ scheduler=scheduler,
160
+ image_encoder=image_encoder,
161
+ image_processor=image_processor,
162
+ )
163
+
164
+ @torch.no_grad()
165
+ @replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING)
166
+ def interpolate(
167
+ self,
168
+ images_and_prompts: List[Union[str, PIL.Image.Image, torch.FloatTensor]],
169
+ weights: List[float],
170
+ num_images_per_prompt: int = 1,
171
+ num_inference_steps: int = 25,
172
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
173
+ latents: Optional[torch.FloatTensor] = None,
174
+ negative_prior_prompt: Optional[str] = None,
175
+ negative_prompt: str = "",
176
+ guidance_scale: float = 4.0,
177
+ device=None,
178
+ ):
179
+ """
180
+ Function invoked when using the prior pipeline for interpolation.
181
+
182
+ Args:
183
+ images_and_prompts (`List[Union[str, PIL.Image.Image, torch.FloatTensor]]`):
184
+ list of prompts and images to guide the image generation.
185
+ weights: (`List[float]`):
186
+ list of weights for each condition in `images_and_prompts`
187
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
188
+ The number of images to generate per prompt.
189
+ num_inference_steps (`int`, *optional*, defaults to 25):
190
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
191
+ expense of slower inference.
192
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
193
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
194
+ to make generation deterministic.
195
+ latents (`torch.FloatTensor`, *optional*):
196
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
197
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
198
+ tensor will be generated by sampling using the supplied random `generator`.
199
+ negative_prior_prompt (`str`, *optional*):
200
+ The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if
201
+ `guidance_scale` is less than `1`).
202
+ negative_prompt (`str` or `List[str]`, *optional*):
203
+ The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if
204
+ `guidance_scale` is less than `1`).
205
+ guidance_scale (`float`, *optional*, defaults to 4.0):
206
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
207
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
208
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
209
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
210
+ usually at the expense of lower image quality.
211
+
212
+ Examples:
213
+
214
+ Returns:
215
+ [`KandinskyPriorPipelineOutput`] or `tuple`
216
+ """
217
+
218
+ device = device or self.device
219
+
220
+ if len(images_and_prompts) != len(weights):
221
+ raise ValueError(
222
+ f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of same length"
223
+ )
224
+
225
+ image_embeddings = []
226
+ for cond, weight in zip(images_and_prompts, weights):
227
+ if isinstance(cond, str):
228
+ image_emb = self(
229
+ cond,
230
+ num_inference_steps=num_inference_steps,
231
+ num_images_per_prompt=num_images_per_prompt,
232
+ generator=generator,
233
+ latents=latents,
234
+ negative_prompt=negative_prior_prompt,
235
+ guidance_scale=guidance_scale,
236
+ ).image_embeds
237
+
238
+ elif isinstance(cond, (PIL.Image.Image, torch.Tensor)):
239
+ if isinstance(cond, PIL.Image.Image):
240
+ cond = (
241
+ self.image_processor(cond, return_tensors="pt")
242
+ .pixel_values[0]
243
+ .unsqueeze(0)
244
+ .to(dtype=self.image_encoder.dtype, device=device)
245
+ )
246
+
247
+ image_emb = self.image_encoder(cond)["image_embeds"]
248
+
249
+ else:
250
+ raise ValueError(
251
+ f"`images_and_prompts` can only contain elements of type `str`, `PIL.Image.Image` or `torch.Tensor` but got {type(cond)}"
252
+ )
253
+
254
+ image_embeddings.append(image_emb * weight)
255
+
256
+ image_emb = torch.cat(image_embeddings).sum(dim=0, keepdim=True)
257
+
258
+ out_zero = self(
259
+ negative_prompt,
260
+ num_inference_steps=num_inference_steps,
261
+ num_images_per_prompt=num_images_per_prompt,
262
+ generator=generator,
263
+ latents=latents,
264
+ negative_prompt=negative_prior_prompt,
265
+ guidance_scale=guidance_scale,
266
+ )
267
+ zero_image_emb = (
268
+ out_zero.negative_image_embeds
269
+ if negative_prompt == ""
270
+ else out_zero.image_embeds
271
+ )
272
+
273
+ return KandinskyPriorPipelineOutput(
274
+ image_embeds=image_emb, negative_image_embeds=zero_image_emb
275
+ )
276
+
277
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
278
+ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
279
+ if latents is None:
280
+ latents = randn_tensor(
281
+ shape, generator=generator, device=device, dtype=dtype
282
+ )
283
+ else:
284
+ if latents.shape != shape:
285
+ raise ValueError(
286
+ f"Unexpected latents shape, got {latents.shape}, expected {shape}"
287
+ )
288
+ latents = latents.to(device)
289
+
290
+ latents = latents * scheduler.init_noise_sigma
291
+ return latents
292
+
293
+ def get_zero_embed(self, batch_size=1, device=None):
294
+ device = device or self.device
295
+ zero_img = torch.zeros(
296
+ 1,
297
+ 3,
298
+ self.image_encoder.config.image_size,
299
+ self.image_encoder.config.image_size,
300
+ ).to(device=device, dtype=self.image_encoder.dtype)
301
+ zero_image_emb = self.image_encoder(zero_img)["image_embeds"]
302
+ zero_image_emb = zero_image_emb.repeat(batch_size, 1)
303
+ return zero_image_emb
304
+
305
+ def _encode_prompt(
306
+ self,
307
+ prompt,
308
+ device,
309
+ num_images_per_prompt,
310
+ do_classifier_free_guidance,
311
+ negative_prompt=None,
312
+ ):
313
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
314
+ # get prompt text embeddings
315
+ text_inputs = self.tokenizer(
316
+ prompt,
317
+ padding="max_length",
318
+ max_length=self.tokenizer.model_max_length,
319
+ truncation=True,
320
+ return_tensors="pt",
321
+ )
322
+ text_input_ids = text_inputs.input_ids
323
+ text_mask = text_inputs.attention_mask.bool().to(device)
324
+
325
+ untruncated_ids = self.tokenizer(
326
+ prompt, padding="longest", return_tensors="pt"
327
+ ).input_ids
328
+
329
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
330
+ text_input_ids, untruncated_ids
331
+ ):
332
+ removed_text = self.tokenizer.batch_decode(
333
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
334
+ )
335
+ logger.warning(
336
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
337
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
338
+ )
339
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
340
+
341
+ text_encoder_output = self.text_encoder(text_input_ids.to(device))
342
+
343
+ prompt_embeds = text_encoder_output.text_embeds
344
+ text_encoder_hidden_states = text_encoder_output.last_hidden_state
345
+
346
+ prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
347
+ text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(
348
+ num_images_per_prompt, dim=0
349
+ )
350
+ text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
351
+
352
+ if do_classifier_free_guidance:
353
+ uncond_tokens: List[str]
354
+ if negative_prompt is None:
355
+ uncond_tokens = [""] * batch_size
356
+ elif type(prompt) is not type(negative_prompt):
357
+ raise TypeError(
358
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
359
+ f" {type(prompt)}."
360
+ )
361
+ elif isinstance(negative_prompt, str):
362
+ uncond_tokens = [negative_prompt]
363
+ elif batch_size != len(negative_prompt):
364
+ raise ValueError(
365
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
366
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
367
+ " the batch size of `prompt`."
368
+ )
369
+ else:
370
+ uncond_tokens = negative_prompt
371
+
372
+ uncond_input = self.tokenizer(
373
+ uncond_tokens,
374
+ padding="max_length",
375
+ max_length=self.tokenizer.model_max_length,
376
+ truncation=True,
377
+ return_tensors="pt",
378
+ )
379
+ uncond_text_mask = uncond_input.attention_mask.bool().to(device)
380
+ negative_prompt_embeds_text_encoder_output = self.text_encoder(
381
+ uncond_input.input_ids.to(device)
382
+ )
383
+
384
+ negative_prompt_embeds = (
385
+ negative_prompt_embeds_text_encoder_output.text_embeds
386
+ )
387
+ uncond_text_encoder_hidden_states = (
388
+ negative_prompt_embeds_text_encoder_output.last_hidden_state
389
+ )
390
+
391
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
392
+
393
+ seq_len = negative_prompt_embeds.shape[1]
394
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
395
+ 1, num_images_per_prompt
396
+ )
397
+ negative_prompt_embeds = negative_prompt_embeds.view(
398
+ batch_size * num_images_per_prompt, seq_len
399
+ )
400
+
401
+ seq_len = uncond_text_encoder_hidden_states.shape[1]
402
+ uncond_text_encoder_hidden_states = (
403
+ uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
404
+ )
405
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
406
+ batch_size * num_images_per_prompt, seq_len, -1
407
+ )
408
+ uncond_text_mask = uncond_text_mask.repeat_interleave(
409
+ num_images_per_prompt, dim=0
410
+ )
411
+
412
+ # done duplicates
413
+
414
+ # For classifier free guidance, we need to do two forward passes.
415
+ # Here we concatenate the unconditional and text embeddings into a single batch
416
+ # to avoid doing two forward passes
417
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
418
+ text_encoder_hidden_states = torch.cat(
419
+ [uncond_text_encoder_hidden_states, text_encoder_hidden_states]
420
+ )
421
+
422
+ text_mask = torch.cat([uncond_text_mask, text_mask])
423
+
424
+ return prompt_embeds, text_encoder_hidden_states, text_mask
425
+
426
+ def enable_model_cpu_offload(self, gpu_id=0):
427
+ r"""
428
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
429
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
430
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
431
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
432
+ """
433
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
434
+ from accelerate import cpu_offload_with_hook
435
+ else:
436
+ raise ImportError(
437
+ "`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher."
438
+ )
439
+
440
+ device = torch.device(f"cuda:{gpu_id}")
441
+
442
+ if self.device.type != "cpu":
443
+ self.to("cpu", silence_dtype_warnings=True)
444
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
445
+
446
+ hook = None
447
+ for cpu_offloaded_model in [self.text_encoder, self.prior]:
448
+ _, hook = cpu_offload_with_hook(
449
+ cpu_offloaded_model, device, prev_module_hook=hook
450
+ )
451
+
452
+ # We'll offload the last model manually.
453
+ self.prior_hook = hook
454
+
455
+ _, hook = cpu_offload_with_hook(
456
+ self.image_encoder, device, prev_module_hook=self.prior_hook
457
+ )
458
+
459
+ self.final_offload_hook = hook
460
+
461
+ @torch.no_grad()
462
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
463
+ def __call__(
464
+ self,
465
+ prompt: Union[str, List[str]],
466
+ negative_prompt: Optional[Union[str, List[str]]] = None,
467
+ num_images_per_prompt: int = 1,
468
+ num_inference_steps: int = 25,
469
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
470
+ latents: Optional[torch.FloatTensor] = None,
471
+ guidance_scale: float = 4.0,
472
+ output_type: Optional[str] = "pt",
473
+ return_dict: bool = True,
474
+ ):
475
+ """
476
+ Function invoked when calling the pipeline for generation.
477
+
478
+ Args:
479
+ prompt (`str` or `List[str]`):
480
+ The prompt or prompts to guide the image generation.
481
+ negative_prompt (`str` or `List[str]`, *optional*):
482
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
483
+ if `guidance_scale` is less than `1`).
484
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
485
+ The number of images to generate per prompt.
486
+ num_inference_steps (`int`, *optional*, defaults to 25):
487
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
488
+ expense of slower inference.
489
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
490
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
491
+ to make generation deterministic.
492
+ latents (`torch.FloatTensor`, *optional*):
493
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
494
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
495
+ tensor will be generated by sampling using the supplied random `generator`.
496
+ guidance_scale (`float`, *optional*, defaults to 4.0):
497
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
498
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
499
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
500
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
501
+ usually at the expense of lower image quality.
502
+ output_type (`str`, *optional*, defaults to `"pt"`):
503
+ The output format of the generated image. Choose between: `"np"` (`np.array`) or `"pt"`
504
+ (`torch.Tensor`).
505
+ return_dict (`bool`, *optional*, defaults to `True`):
506
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
507
+
508
+ Examples:
509
+
510
+ Returns:
511
+ [`KandinskyPriorPipelineOutput`] or `tuple`
512
+ """
513
+
514
+ if isinstance(prompt, str):
515
+ prompt = [prompt]
516
+ elif not isinstance(prompt, list):
517
+ raise ValueError(
518
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
519
+ )
520
+
521
+ if isinstance(negative_prompt, str):
522
+ negative_prompt = [negative_prompt]
523
+ elif not isinstance(negative_prompt, list) and negative_prompt is not None:
524
+ raise ValueError(
525
+ f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}"
526
+ )
527
+
528
+ # if the negative prompt is defined we double the batch size to
529
+ # directly retrieve the negative prompt embedding
530
+ if negative_prompt is not None:
531
+ prompt = prompt + negative_prompt
532
+ negative_prompt = 2 * negative_prompt
533
+
534
+ device = self._execution_device
535
+
536
+ batch_size = len(prompt)
537
+ batch_size = batch_size * num_images_per_prompt
538
+
539
+ prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
540
+ prompt, device, num_images_per_prompt, False, negative_prompt
541
+ )
542
+
543
+ hidden_states = randn_tensor(
544
+ (batch_size, prompt_embeds.shape[-1]),
545
+ device=prompt_embeds.device,
546
+ dtype=prompt_embeds.dtype,
547
+ generator=generator,
548
+ )
549
+
550
+ latents = self.prior(
551
+ hidden_states,
552
+ proj_embedding=prompt_embeds,
553
+ encoder_hidden_states=text_encoder_hidden_states,
554
+ attention_mask=text_mask,
555
+ ).predicted_image_embedding
556
+
557
+ image_embeddings = latents
558
+
559
+ # if a negative prompt has been defined, we split the image embedding into two
560
+ if negative_prompt is None:
561
+ zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
562
+
563
+ if (
564
+ hasattr(self, "final_offload_hook")
565
+ and self.final_offload_hook is not None
566
+ ):
567
+ self.final_offload_hook.offload()
568
+ else:
569
+ image_embeddings, zero_embeds = image_embeddings.chunk(2)
570
+
571
+ if (
572
+ hasattr(self, "final_offload_hook")
573
+ and self.final_offload_hook is not None
574
+ ):
575
+ self.prior_hook.offload()
576
+
577
+ if output_type not in ["pt", "np"]:
578
+ raise ValueError(
579
+ f"Only the output types `pt` and `np` are supported not output_type={output_type}"
580
+ )
581
+
582
+ if output_type == "np":
583
+ image_embeddings = image_embeddings.cpu().numpy()
584
+ zero_embeds = zero_embeds.cpu().numpy()
585
+
586
+ if not return_dict:
587
+ return (image_embeddings, zero_embeds)
588
+
589
+ return KandinskyPriorPipelineOutput(
590
+ image_embeds=image_embeddings, negative_image_embeds=zero_embeds
591
+ )
src/pipelines/pipeline_kandinsky_subject_prior.py ADDED
@@ -0,0 +1,621 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional, Union
3
+
4
+ import numpy as np
5
+ from PIL import Image
6
+ import PIL
7
+ import torch
8
+ from transformers import (
9
+ CLIPImageProcessor,
10
+ CLIPTextModelWithProjection,
11
+ CLIPTokenizer,
12
+ CLIPVisionModelWithProjection,
13
+ )
14
+
15
+ from diffusers.models import PriorTransformer
16
+ from diffusers.schedulers import UnCLIPScheduler
17
+ from diffusers.utils import (
18
+ BaseOutput,
19
+ is_accelerate_available,
20
+ is_accelerate_version,
21
+ logging,
22
+ # randn_tensor,
23
+ replace_example_docstring,
24
+ )
25
+ from diffusers.utils.torch_utils import randn_tensor
26
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
27
+
28
+
29
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30
+
31
+ EXAMPLE_DOC_STRING = """
32
+ Examples:
33
+ ```py
34
+ >>> from diffusers import KandinskyPipeline, KandinskyPriorPipeline
35
+ >>> import torch
36
+
37
+ >>> pipe_prior = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior")
38
+ >>> pipe_prior.to("cuda")
39
+
40
+ >>> prompt = "red cat, 4k photo"
41
+ >>> out = pipe_prior(prompt)
42
+ >>> image_emb = out.image_embeds
43
+ >>> negative_image_emb = out.negative_image_embeds
44
+
45
+ >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1")
46
+ >>> pipe.to("cuda")
47
+
48
+ >>> image = pipe(
49
+ ... prompt,
50
+ ... image_embeds=image_emb,
51
+ ... negative_image_embeds=negative_image_emb,
52
+ ... height=768,
53
+ ... width=768,
54
+ ... num_inference_steps=100,
55
+ ... ).images
56
+
57
+ >>> image[0].save("cat.png")
58
+ ```
59
+ """
60
+
61
+ EXAMPLE_INTERPOLATE_DOC_STRING = """
62
+ Examples:
63
+ ```py
64
+ >>> from diffusers import KandinskyPriorPipeline, KandinskyPipeline
65
+ >>> from diffusers.utils import load_image
66
+ >>> import PIL
67
+
68
+ >>> import torch
69
+ >>> from torchvision import transforms
70
+
71
+ >>> pipe_prior = KandinskyPriorPipeline.from_pretrained(
72
+ ... "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
73
+ ... )
74
+ >>> pipe_prior.to("cuda")
75
+
76
+ >>> img1 = load_image(
77
+ ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
78
+ ... "/kandinsky/cat.png"
79
+ ... )
80
+
81
+ >>> img2 = load_image(
82
+ ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
83
+ ... "/kandinsky/starry_night.jpeg"
84
+ ... )
85
+
86
+ >>> images_texts = ["a cat", img1, img2]
87
+ >>> weights = [0.3, 0.3, 0.4]
88
+ >>> image_emb, zero_image_emb = pipe_prior.interpolate(images_texts, weights)
89
+
90
+ >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
91
+ >>> pipe.to("cuda")
92
+
93
+ >>> image = pipe(
94
+ ... "",
95
+ ... image_embeds=image_emb,
96
+ ... negative_image_embeds=zero_image_emb,
97
+ ... height=768,
98
+ ... width=768,
99
+ ... num_inference_steps=150,
100
+ ... ).images[0]
101
+
102
+ >>> image.save("starry_cat.png")
103
+ ```
104
+ """
105
+
106
+
107
+ @dataclass
108
+ class KandinskyPriorPipelineOutput(BaseOutput):
109
+ """
110
+ Output class for KandinskyPriorPipeline.
111
+
112
+ Args:
113
+ image_embeds (`torch.FloatTensor`)
114
+ clip image embeddings for text prompt
115
+ negative_image_embeds (`List[PIL.Image.Image]` or `np.ndarray`)
116
+ clip image embeddings for unconditional tokens
117
+ """
118
+
119
+ image_embeds: Union[torch.FloatTensor, np.ndarray]
120
+ negative_image_embeds: Union[torch.FloatTensor, np.ndarray]
121
+
122
+
123
+ class KandinskyPriorPipeline(DiffusionPipeline):
124
+ """
125
+ Pipeline for generating image prior for Kandinsky
126
+
127
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
128
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
129
+
130
+ Args:
131
+ prior ([`PriorTransformer`]):
132
+ The canonical unCLIP prior to approximate the image embedding from the text embedding.
133
+ image_encoder ([`CLIPVisionModelWithProjection`]):
134
+ Frozen image-encoder.
135
+ text_encoder ([`CLIPTextModelWithProjection`]):
136
+ Frozen text-encoder.
137
+ tokenizer (`CLIPTokenizer`):
138
+ Tokenizer of class
139
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
140
+ scheduler ([`UnCLIPScheduler`]):
141
+ A scheduler to be used in combination with `prior` to generate image embedding.
142
+ """
143
+
144
+ _exclude_from_cpu_offload = ["prior"]
145
+
146
+ def __init__(
147
+ self,
148
+ prior: PriorTransformer,
149
+ image_encoder: CLIPVisionModelWithProjection,
150
+ text_encoder: CLIPTextModelWithProjection,
151
+ tokenizer: CLIPTokenizer,
152
+ scheduler: UnCLIPScheduler,
153
+ image_processor: CLIPImageProcessor,
154
+ ):
155
+ super().__init__()
156
+
157
+ self.register_modules(
158
+ prior=prior,
159
+ text_encoder=text_encoder,
160
+ tokenizer=tokenizer,
161
+ scheduler=scheduler,
162
+ image_encoder=image_encoder,
163
+ image_processor=image_processor,
164
+ )
165
+
166
+ @torch.no_grad()
167
+ @replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING)
168
+ def interpolate(
169
+ self,
170
+ images_and_prompts: List[Union[str, PIL.Image.Image, torch.FloatTensor]],
171
+ weights: List[float],
172
+ num_images_per_prompt: int = 1,
173
+ num_inference_steps: int = 25,
174
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
175
+ latents: Optional[torch.FloatTensor] = None,
176
+ negative_prior_prompt: Optional[str] = None,
177
+ negative_prompt: str = "",
178
+ guidance_scale: float = 4.0,
179
+ device=None,
180
+ ):
181
+ """
182
+ Function invoked when using the prior pipeline for interpolation.
183
+
184
+ Args:
185
+ images_and_prompts (`List[Union[str, PIL.Image.Image, torch.FloatTensor]]`):
186
+ list of prompts and images to guide the image generation.
187
+ weights: (`List[float]`):
188
+ list of weights for each condition in `images_and_prompts`
189
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
190
+ The number of images to generate per prompt.
191
+ num_inference_steps (`int`, *optional*, defaults to 25):
192
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
193
+ expense of slower inference.
194
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
195
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
196
+ to make generation deterministic.
197
+ latents (`torch.FloatTensor`, *optional*):
198
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
199
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
200
+ tensor will be generated by sampling using the supplied random `generator`.
201
+ negative_prior_prompt (`str`, *optional*):
202
+ The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if
203
+ `guidance_scale` is less than `1`).
204
+ negative_prompt (`str` or `List[str]`, *optional*):
205
+ The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if
206
+ `guidance_scale` is less than `1`).
207
+ guidance_scale (`float`, *optional*, defaults to 4.0):
208
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
209
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
210
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
211
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
212
+ usually at the expense of lower image quality.
213
+
214
+ Examples:
215
+
216
+ Returns:
217
+ [`KandinskyPriorPipelineOutput`] or `tuple`
218
+ """
219
+
220
+ device = device or self.device
221
+
222
+ if len(images_and_prompts) != len(weights):
223
+ raise ValueError(
224
+ f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of same length"
225
+ )
226
+
227
+ image_embeddings = []
228
+ for cond, weight in zip(images_and_prompts, weights):
229
+ if isinstance(cond, str):
230
+ image_emb = self(
231
+ cond,
232
+ num_inference_steps=num_inference_steps,
233
+ num_images_per_prompt=num_images_per_prompt,
234
+ generator=generator,
235
+ latents=latents,
236
+ negative_prompt=negative_prior_prompt,
237
+ guidance_scale=guidance_scale,
238
+ ).image_embeds
239
+
240
+ elif isinstance(cond, (PIL.Image.Image, torch.Tensor)):
241
+ if isinstance(cond, PIL.Image.Image):
242
+ cond = (
243
+ self.image_processor(cond, return_tensors="pt")
244
+ .pixel_values[0]
245
+ .unsqueeze(0)
246
+ .to(dtype=self.image_encoder.dtype, device=device)
247
+ )
248
+
249
+ image_emb = self.image_encoder(cond)["image_embeds"]
250
+
251
+ else:
252
+ raise ValueError(
253
+ f"`images_and_prompts` can only contain elements of type `str`, `PIL.Image.Image` or `torch.Tensor` but got {type(cond)}"
254
+ )
255
+
256
+ image_embeddings.append(image_emb * weight)
257
+
258
+ image_emb = torch.cat(image_embeddings).sum(dim=0, keepdim=True)
259
+
260
+ out_zero = self(
261
+ negative_prompt,
262
+ num_inference_steps=num_inference_steps,
263
+ num_images_per_prompt=num_images_per_prompt,
264
+ generator=generator,
265
+ latents=latents,
266
+ negative_prompt=negative_prior_prompt,
267
+ guidance_scale=guidance_scale,
268
+ )
269
+ zero_image_emb = (
270
+ out_zero.negative_image_embeds
271
+ if negative_prompt == ""
272
+ else out_zero.image_embeds
273
+ )
274
+
275
+ return KandinskyPriorPipelineOutput(
276
+ image_embeds=image_emb, negative_image_embeds=zero_image_emb
277
+ )
278
+
279
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
280
+ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
281
+ if latents is None:
282
+ latents = randn_tensor(
283
+ shape, generator=generator, device=device, dtype=dtype
284
+ )
285
+ else:
286
+ if latents.shape != shape:
287
+ raise ValueError(
288
+ f"Unexpected latents shape, got {latents.shape}, expected {shape}"
289
+ )
290
+ latents = latents.to(device)
291
+
292
+ latents = latents * scheduler.init_noise_sigma
293
+ return latents
294
+
295
+ def get_zero_embed(self, batch_size=1, device=None):
296
+ device = device or self.device
297
+ zero_img = torch.zeros(
298
+ 1,
299
+ 3,
300
+ self.image_encoder.config.image_size,
301
+ self.image_encoder.config.image_size,
302
+ ).to(device=device, dtype=self.image_encoder.dtype)
303
+ zero_image_emb = self.image_encoder(zero_img)["image_embeds"]
304
+ zero_image_emb = zero_image_emb.repeat(batch_size, 1)
305
+ return zero_image_emb
306
+
307
+ def _encode_prompt(
308
+ self,
309
+ prompt,
310
+ device,
311
+ num_images_per_prompt,
312
+ do_classifier_free_guidance,
313
+ negative_prompt=None,
314
+ ):
315
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
316
+ # get prompt text embeddings
317
+ text_inputs = self.tokenizer(
318
+ prompt,
319
+ padding="max_length",
320
+ max_length=self.tokenizer.model_max_length,
321
+ truncation=True,
322
+ return_tensors="pt",
323
+ )
324
+ text_input_ids = text_inputs.input_ids
325
+ text_mask = text_inputs.attention_mask.bool().to(device)
326
+
327
+ untruncated_ids = self.tokenizer(
328
+ prompt, padding="longest", return_tensors="pt"
329
+ ).input_ids
330
+
331
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
332
+ text_input_ids, untruncated_ids
333
+ ):
334
+ removed_text = self.tokenizer.batch_decode(
335
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
336
+ )
337
+ logger.warning(
338
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
339
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
340
+ )
341
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
342
+
343
+ text_encoder_output = self.text_encoder(text_input_ids.to(device))
344
+
345
+ prompt_embeds = text_encoder_output.text_embeds
346
+ text_encoder_hidden_states = text_encoder_output.last_hidden_state
347
+
348
+ prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
349
+ text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(
350
+ num_images_per_prompt, dim=0
351
+ )
352
+ text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
353
+
354
+ if do_classifier_free_guidance:
355
+ uncond_tokens: List[str]
356
+ if negative_prompt is None:
357
+ uncond_tokens = [""] * batch_size
358
+ elif type(prompt) is not type(negative_prompt):
359
+ raise TypeError(
360
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
361
+ f" {type(prompt)}."
362
+ )
363
+ elif isinstance(negative_prompt, str):
364
+ uncond_tokens = [negative_prompt]
365
+ elif batch_size != len(negative_prompt):
366
+ raise ValueError(
367
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
368
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
369
+ " the batch size of `prompt`."
370
+ )
371
+ else:
372
+ uncond_tokens = negative_prompt
373
+
374
+ uncond_input = self.tokenizer(
375
+ uncond_tokens,
376
+ padding="max_length",
377
+ max_length=self.tokenizer.model_max_length,
378
+ truncation=True,
379
+ return_tensors="pt",
380
+ )
381
+ uncond_text_mask = uncond_input.attention_mask.bool().to(device)
382
+ negative_prompt_embeds_text_encoder_output = self.text_encoder(
383
+ uncond_input.input_ids.to(device)
384
+ )
385
+
386
+ negative_prompt_embeds = (
387
+ negative_prompt_embeds_text_encoder_output.text_embeds
388
+ )
389
+ uncond_text_encoder_hidden_states = (
390
+ negative_prompt_embeds_text_encoder_output.last_hidden_state
391
+ )
392
+
393
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
394
+
395
+ seq_len = negative_prompt_embeds.shape[1]
396
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
397
+ 1, num_images_per_prompt
398
+ )
399
+ negative_prompt_embeds = negative_prompt_embeds.view(
400
+ batch_size * num_images_per_prompt, seq_len
401
+ )
402
+
403
+ seq_len = uncond_text_encoder_hidden_states.shape[1]
404
+ uncond_text_encoder_hidden_states = (
405
+ uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
406
+ )
407
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
408
+ batch_size * num_images_per_prompt, seq_len, -1
409
+ )
410
+ uncond_text_mask = uncond_text_mask.repeat_interleave(
411
+ num_images_per_prompt, dim=0
412
+ )
413
+
414
+ # done duplicates
415
+
416
+ # For classifier free guidance, we need to do two forward passes.
417
+ # Here we concatenate the unconditional and text embeddings into a single batch
418
+ # to avoid doing two forward passes
419
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
420
+ text_encoder_hidden_states = torch.cat(
421
+ [uncond_text_encoder_hidden_states, text_encoder_hidden_states]
422
+ )
423
+
424
+ text_mask = torch.cat([uncond_text_mask, text_mask])
425
+
426
+ return prompt_embeds, text_encoder_hidden_states, text_mask
427
+
428
+ def enable_model_cpu_offload(self, gpu_id=0):
429
+ r"""
430
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
431
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
432
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
433
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
434
+ """
435
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
436
+ from accelerate import cpu_offload_with_hook
437
+ else:
438
+ raise ImportError(
439
+ "`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher."
440
+ )
441
+
442
+ device = torch.device(f"cuda:{gpu_id}")
443
+
444
+ if self.device.type != "cpu":
445
+ self.to("cpu", silence_dtype_warnings=True)
446
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
447
+
448
+ hook = None
449
+ for cpu_offloaded_model in [self.text_encoder, self.prior]:
450
+ _, hook = cpu_offload_with_hook(
451
+ cpu_offloaded_model, device, prev_module_hook=hook
452
+ )
453
+
454
+ # We'll offload the last model manually.
455
+ self.prior_hook = hook
456
+
457
+ _, hook = cpu_offload_with_hook(
458
+ self.image_encoder, device, prev_module_hook=self.prior_hook
459
+ )
460
+
461
+ self.final_offload_hook = hook
462
+
463
+ @torch.no_grad()
464
+ def get_text_feats(self, raw_data):
465
+ prompt = raw_data["prompt"]
466
+ txt = self.tokenizer(
467
+ prompt,
468
+ padding="max_length",
469
+ truncation=True,
470
+ return_tensors="pt",
471
+ )
472
+ txt_items = {k: v.to("cuda") for k, v in txt.items()}
473
+ txt_feats = self.text_encoder(**txt_items)
474
+ last_hidden_states = txt_feats.last_hidden_state[0].detach().cpu().numpy()
475
+ prompt_embeds = txt_feats.text_embeds.detach().cpu()
476
+ attention_mask = txt_items["attention_mask"]
477
+
478
+ for sub_img, sub_name in zip(raw_data["subject_images"], raw_data["subject_keywords"]):
479
+ if isinstance(sub_img, str):
480
+ sub_img = Image.open(sub_img)
481
+ mask_img = self.image_processor(sub_img, return_tensors="pt").to("cuda")
482
+ vision_feats = self.image_encoder(**mask_img).image_embeds
483
+ entity_tokens = self.tokenizer(sub_name)["input_ids"][1:-1]
484
+
485
+ found = True
486
+ for tid in entity_tokens:
487
+ indices = np.where(txt_items["input_ids"][0].cpu().numpy() == tid)[0]
488
+ if len(indices)==0:
489
+ found = False
490
+ last_hidden_states[indices] = vision_feats[0].cpu().numpy()
491
+
492
+ if not found:
493
+ print(f"Couldn't find keyword '{sub_name}' in the prompt.")
494
+
495
+ text_feats = {
496
+ "prompt_embeds": prompt_embeds,
497
+ "text_encoder_hidden_states": torch.tensor(last_hidden_states).unsqueeze(0),
498
+ "text_mask": attention_mask,
499
+ }
500
+
501
+ return text_feats
502
+
503
+ @torch.no_grad()
504
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
505
+ def __call__(
506
+ self,
507
+ text_feats: dict = None,
508
+ raw_data: dict = None,
509
+ num_images_per_prompt: int = 1,
510
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
511
+ latents: Optional[torch.FloatTensor] = None,
512
+ output_type: Optional[str] = "pt",
513
+ return_dict: bool = True,
514
+ control_embedding: torch.FloatTensor = None,
515
+ ):
516
+ """
517
+ Function invoked when calling the pipeline for generation.
518
+
519
+ Args:
520
+ text_feats (`dict`, *optional*, defaults to None):
521
+ "prompt_embeds", "text_encoder_hidden_states", "text_mask"
522
+ raw_data (`dict`, *optional*, defaults to None):
523
+ "prompt": str,
524
+ "subject_images": List of str or PIL
525
+ "subject_keywords": List of str
526
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
527
+ The number of images to generate per prompt.
528
+ num_inference_steps (`int`, *optional*, defaults to 25):
529
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
530
+ expense of slower inference.
531
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
532
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
533
+ to make generation deterministic.
534
+ latents (`torch.FloatTensor`, *optional*):
535
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
536
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
537
+ tensor will be generated by sampling using the supplied random `generator`.
538
+ output_type (`str`, *optional*, defaults to `"pt"`):
539
+ The output format of the generated image. Choose between: `"np"` (`np.array`) or `"pt"`
540
+ (`torch.Tensor`).
541
+ return_dict (`bool`, *optional*, defaults to `True`):
542
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
543
+
544
+ Examples:
545
+
546
+ Returns:
547
+ [`KandinskyPriorPipelineOutput`] or `tuple`
548
+ """
549
+ assert text_feats or raw_data, "please provide either raw_data or pre-processed text_feats"
550
+ assert num_images_per_prompt==1
551
+
552
+ if text_feats is None:
553
+ text_feats = self.get_text_feats(raw_data)
554
+
555
+ device = self._execution_device
556
+ for k,v in text_feats.items():
557
+ text_feats[k] = v.to(device)
558
+
559
+ if control_embedding is None:
560
+ control_embedding = self.get_zero_embed(1, device=device)
561
+
562
+ batch_size = text_feats["prompt_embeds"].shape[0]
563
+ assert batch_size == 1
564
+
565
+ batch_size = batch_size * num_images_per_prompt
566
+
567
+ prompt_embeds = text_feats["prompt_embeds"]
568
+ text_encoder_hidden_states = text_feats["text_encoder_hidden_states"]
569
+ text_mask = text_feats["text_mask"]
570
+
571
+ hidden_states = randn_tensor(
572
+ (batch_size, prompt_embeds.shape[-1]),
573
+ device=prompt_embeds.device,
574
+ dtype=prompt_embeds.dtype,
575
+ generator=generator,
576
+ )
577
+
578
+ latents = self.prior(
579
+ hidden_states,
580
+ proj_embedding=prompt_embeds,
581
+ encoder_hidden_states=text_encoder_hidden_states,
582
+ attention_mask=text_mask,
583
+ control_embedding=control_embedding,
584
+ ).predicted_image_embedding
585
+
586
+ image_embeddings = latents
587
+
588
+ # if a negative prompt has been defined, we split the image embedding into two
589
+ negative_prompt = None
590
+ if negative_prompt is None:
591
+ zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
592
+
593
+ if (
594
+ hasattr(self, "final_offload_hook")
595
+ and self.final_offload_hook is not None
596
+ ):
597
+ self.final_offload_hook.offload()
598
+ else:
599
+ image_embeddings, zero_embeds = image_embeddings.chunk(2)
600
+
601
+ if (
602
+ hasattr(self, "final_offload_hook")
603
+ and self.final_offload_hook is not None
604
+ ):
605
+ self.prior_hook.offload()
606
+
607
+ if output_type not in ["pt", "np"]:
608
+ raise ValueError(
609
+ f"Only the output types `pt` and `np` are supported not output_type={output_type}"
610
+ )
611
+
612
+ if output_type == "np":
613
+ image_embeddings = image_embeddings.cpu().numpy()
614
+ zero_embeds = zero_embeds.cpu().numpy()
615
+
616
+ if not return_dict:
617
+ return (image_embeddings, zero_embeds)
618
+
619
+ return KandinskyPriorPipelineOutput(
620
+ image_embeds=image_embeddings, negative_image_embeds=zero_embeds
621
+ )
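
For reference, a hedged sketch (not part of the commit) of the two entry points this subject-prior pipeline exposes: passing `raw_data` so `get_text_feats` builds the conditioning internally, or passing a pre-computed `text_feats` dict with the keys documented above. It assumes a CUDA device, since `get_text_feats` moves tensors to `"cuda"`, and reuses the checkpoints and asset paths from `app.py`:

```py
import torch
from src.priors.lambda_prior_transformer import PriorTransformer
from src.pipelines.pipeline_kandinsky_subject_prior import KandinskyPriorPipeline

prior = PriorTransformer.from_pretrained(
    "ECLIPSE-Community/Lambda-ECLIPSE-Prior-v1.0", torch_dtype=torch.float16
)
pipe_prior = KandinskyPriorPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-prior", prior=prior, torch_dtype=torch.float16
).to("cuda")

raw_data = {
    "prompt": "A cat wearing glasses on a snowy field",
    "subject_images": ["./assets/cat.png", "./assets/blue_sunglasses.png"],
    "subject_keywords": ["cat", "glasses"],
}

# Option 1: let __call__ build the conditioning from raw_data.
out = pipe_prior(raw_data=raw_data)

# Option 2: pre-compute text_feats once ("prompt_embeds", "text_encoder_hidden_states",
# "text_mask") and reuse it across seeds.
text_feats = pipe_prior.get_text_feats(raw_data)
out = pipe_prior(text_feats=text_feats, generator=torch.Generator("cuda").manual_seed(0))

image_emb, negative_image_emb = out.to_tuple()
```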
src/priors/__init__.py ADDED
File without changes
src/priors/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (190 Bytes).
 
src/priors/__pycache__/lambda_prior_transformer.cpython-39.pyc ADDED
Binary file (12.1 kB).
 
src/priors/lambda_prior_transformer.py ADDED
@@ -0,0 +1,373 @@
1
+ from dataclasses import dataclass
2
+ from typing import Dict, Optional, Union
3
+
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
10
+ from diffusers.utils import BaseOutput
11
+ from diffusers.models.attention import BasicTransformerBlock
12
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
13
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
14
+ from diffusers.models.modeling_utils import ModelMixin
15
+
16
+
17
+ @dataclass
18
+ class PriorTransformerOutput(BaseOutput):
19
+ """
20
+ The output of [`PriorTransformer`].
21
+
22
+ Args:
23
+ predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
24
+ The predicted CLIP image embedding conditioned on the CLIP text embedding input.
25
+ """
26
+
27
+ predicted_image_embedding: torch.FloatTensor
28
+
29
+
30
+ class PriorTransformer(ModelMixin, ConfigMixin):
31
+ """
32
+ A Prior Transformer model.
33
+
34
+ Parameters:
35
+ num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
36
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
37
+ num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
38
+ embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
39
+ num_embeddings (`int`, *optional*, defaults to 77):
40
+ The number of embeddings of the model input `hidden_states`
41
+ additional_embeddings (`int`, *optional*, defaults to 3): The number of additional tokens appended to the
42
+ projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
43
+ additional_embeddings`.
44
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
45
+ time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
46
+ The activation function to use to create timestep embeddings.
47
+ norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
48
+ passing to Transformer blocks. Set it to `None` if normalization is not needed.
49
+ embedding_proj_norm_type (`str`, *optional*, defaults to None):
50
+ The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
51
+ needed.
52
+ encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
53
+ The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
54
+ `encoder_hidden_states` is `None`.
55
+ added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
56
+ Choose from `prd` or `None`. If `prd` is chosen, it will prepend a token indicating the (quantized) dot
57
+ product between the text embedding and image embedding as proposed in the unclip paper
58
+ https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended.
59
+ time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings.
60
+ If None, will be set to `num_attention_heads * attention_head_dim`
61
+ embedding_proj_dim (`int`, *optional*, defaults to None):
62
+ The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
63
+ clip_embed_dim (`int`, *optional*, defaults to None):
64
+ The dimension of the output. If None, will be set to `embedding_dim`.
65
+ """
66
+
67
+ @register_to_config
68
+ def __init__(
69
+ self,
70
+ num_attention_heads: int = 32,
71
+ attention_head_dim: int = 64,
72
+ num_layers: int = 20,
73
+ embedding_dim: int = 768,
74
+ num_embeddings=77,
75
+ additional_embeddings=3, # as we have removed the time embedding
76
+ dropout: float = 0.0,
77
+ # time_embed_act_fn: str = "silu",
78
+ norm_in_type: Optional[str] = None, # layer
79
+ embedding_proj_norm_type: Optional[str] = None, # layer
80
+ encoder_hid_proj_type: Optional[str] = "linear", # linear
81
+ added_emb_type: Optional[str] = "prd", # prd
82
+ # time_embed_dim: Optional[int] = None,
83
+ embedding_proj_dim: Optional[int] = None,
84
+ clip_embed_dim: Optional[int] = None,
85
+ ):
86
+ super().__init__()
87
+ self.num_attention_heads = num_attention_heads
88
+ self.attention_head_dim = attention_head_dim
89
+ inner_dim = num_attention_heads * attention_head_dim
90
+ self.additional_embeddings = additional_embeddings
91
+
92
+ # time_embed_dim = time_embed_dim or inner_dim
93
+ embedding_proj_dim = embedding_proj_dim or embedding_dim
94
+ clip_embed_dim = clip_embed_dim or embedding_dim
95
+
96
+ # self.time_proj = Timesteps(inner_dim, True, 0)
97
+ # self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)
98
+
99
+ self.proj_in = nn.Linear(embedding_dim, inner_dim)
100
+
101
+ if embedding_proj_norm_type is None:
102
+ self.embedding_proj_norm = None
103
+ elif embedding_proj_norm_type == "layer":
104
+ self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
105
+ else:
106
+ raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")
107
+
108
+ self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)
109
+ self.embedding_control = nn.Linear(embedding_proj_dim, inner_dim)
110
+
111
+ if encoder_hid_proj_type is None:
112
+ self.encoder_hidden_states_proj = None
113
+ elif encoder_hid_proj_type == "linear":
114
+ self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
115
+ else:
116
+ raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")
117
+
118
+ self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
119
+
120
+ if added_emb_type == "prd":
121
+ self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
122
+ elif added_emb_type is None:
123
+ self.prd_embedding = None
124
+ else:
125
+ raise ValueError(
126
+ f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
127
+ )
128
+
129
+ self.transformer_blocks = nn.ModuleList(
130
+ [
131
+ BasicTransformerBlock(
132
+ inner_dim,
133
+ num_attention_heads,
134
+ attention_head_dim,
135
+ dropout=dropout,
136
+ activation_fn="gelu",
137
+ attention_bias=True,
138
+ )
139
+ for d in range(num_layers)
140
+ ]
141
+ )
142
+
143
+ if norm_in_type == "layer":
144
+ self.norm_in = nn.LayerNorm(inner_dim)
145
+ elif norm_in_type is None:
146
+ self.norm_in = None
147
+ else:
148
+ raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")
149
+
150
+ self.norm_out = nn.LayerNorm(inner_dim)
151
+
152
+ self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)
153
+
154
+ causal_attention_mask = torch.full(
155
+ [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
156
+ )
157
+ causal_attention_mask.triu_(1)
158
+ causal_attention_mask = causal_attention_mask[None, ...]
159
+ self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
160
+
161
+ self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
162
+ self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
163
+
164
+ @property
165
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
166
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
167
+ r"""
168
+ Returns:
169
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
170
+ indexed by their weight names.
171
+ """
172
+ # set recursively
173
+ processors = {}
174
+
175
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
176
+ if hasattr(module, "set_processor"):
177
+ processors[f"{name}.processor"] = module.processor
178
+
179
+ for sub_name, child in module.named_children():
180
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
181
+
182
+ return processors
183
+
184
+ for name, module in self.named_children():
185
+ fn_recursive_add_processors(name, module, processors)
186
+
187
+ return processors
188
+
189
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
190
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
191
+ r"""
192
+ Sets the attention processor to use to compute attention.
193
+
194
+ Parameters:
195
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
196
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
197
+ for **all** `Attention` layers.
198
+
199
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
200
+ processor. This is strongly recommended when setting trainable attention processors.
201
+
202
+ """
203
+ count = len(self.attn_processors.keys())
204
+
205
+ if isinstance(processor, dict) and len(processor) != count:
206
+ raise ValueError(
207
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
208
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
209
+ )
210
+
211
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
212
+ if hasattr(module, "set_processor"):
213
+ if not isinstance(processor, dict):
214
+ module.set_processor(processor)
215
+ else:
216
+ module.set_processor(processor.pop(f"{name}.processor"))
217
+
218
+ for sub_name, child in module.named_children():
219
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
220
+
221
+ for name, module in self.named_children():
222
+ fn_recursive_attn_processor(name, module, processor)
223
+
224
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
225
+ def set_default_attn_processor(self):
226
+ """
227
+ Disables custom attention processors and sets the default attention implementation.
228
+ """
229
+ self.set_attn_processor(AttnProcessor())
230
+
231
+ def forward(
232
+ self,
233
+ hidden_states,
234
+ # timestep: Union[torch.Tensor, float, int],
235
+ proj_embedding: torch.FloatTensor,
236
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
237
+ attention_mask: Optional[torch.BoolTensor] = None,
238
+ control_embedding: Optional[torch.FloatTensor] = None,
239
+ return_dict: bool = True,
240
+ ):
241
+ """
242
+ The [`PriorTransformer`] forward method.
243
+
244
+ Args:
245
+ hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
246
+ The currently predicted image embeddings.
249
+ proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
250
+ Projected embedding vector the denoising process is conditioned on.
251
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
252
+ Hidden states of the text embeddings the denoising process is conditioned on.
253
+ attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
254
+ Text mask for the text embeddings.
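+ control_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
+ Control/subject embedding the prediction is additionally conditioned on; it is projected by
+ `embedding_control` and prepended to the token sequence.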
255
+ return_dict (`bool`, *optional*, defaults to `True`):
256
+ Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
257
+ tuple.
258
+
259
+ Returns:
260
+ [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
261
+ If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a
262
+ tuple is returned where the first element is the sample tensor.
263
+ """
264
+ batch_size = hidden_states.shape[0]
265
+
266
+ # timesteps = timestep
267
+ # if not torch.is_tensor(timesteps):
268
+ # timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
269
+ # elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
270
+ # timesteps = timesteps[None].to(hidden_states.device)
271
+
272
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
273
+ # timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)
274
+
275
+ # timesteps_projected = self.time_proj(timesteps)
276
+
277
+ # timesteps does not contain any weights and will always return f32 tensors
278
+ # but time_embedding might be fp16, so we need to cast here.
279
+ # timesteps_projected = timesteps_projected.to(dtype=self.dtype)
280
+ # time_embeddings = self.time_embedding(timesteps_projected)
281
+
282
+ if self.embedding_proj_norm is not None:
283
+ proj_embedding = self.embedding_proj_norm(proj_embedding)
284
+
285
+ proj_embeddings = self.embedding_proj(proj_embedding)
286
+ control_embeddings = self.embedding_control(control_embedding)
287
+
288
+ if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
289
+ encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
290
+ elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
291
+ raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")
292
+
293
+ hidden_states = self.proj_in(hidden_states)
294
+
295
+ positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
296
+
297
+ additional_embeds = []
298
+ additional_embeddings_len = 0
299
+
300
+ if encoder_hidden_states is not None:
301
+ additional_embeds.append(encoder_hidden_states)
302
+ additional_embeddings_len += encoder_hidden_states.shape[1]
303
+
304
+ if len(proj_embeddings.shape) == 2:
305
+ proj_embeddings = proj_embeddings[:, None, :]
306
+
307
+ if len(control_embeddings.shape) == 2:
308
+ control_embeddings = control_embeddings[:, None, :]
309
+
310
+ if len(hidden_states.shape) == 2:
311
+ hidden_states = hidden_states[:, None, :]
312
+
313
+ additional_embeds = additional_embeds + [
314
+ control_embeddings,
315
+ proj_embeddings,
316
+ # time_embeddings[:, None, :],
317
+ hidden_states,
318
+ ]
319
+
320
+ if self.prd_embedding is not None:
321
+ prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
322
+ additional_embeds.append(prd_embedding)
323
+
324
+ hidden_states = torch.cat(
325
+ additional_embeds,
326
+ dim=1,
327
+ )
328
+
329
+ # Allow positional_embedding to not include the `additional_embeddings` and instead pad it with zeros for these additional tokens
330
+ additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
331
+ if positional_embeddings.shape[1] < hidden_states.shape[1]:
332
+ positional_embeddings = F.pad(
333
+ positional_embeddings,
334
+ (
335
+ 0,
336
+ 0,
337
+ additional_embeddings_len,
338
+ self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
339
+ ),
340
+ value=0.0,
341
+ )
342
+
343
+ hidden_states = hidden_states + positional_embeddings
344
+
345
+ if attention_mask is not None:
346
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
347
+ attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
348
+ attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
349
+ attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
350
+
351
+ if self.norm_in is not None:
352
+ hidden_states = self.norm_in(hidden_states)
353
+
354
+ for block in self.transformer_blocks:
355
+ hidden_states = block(hidden_states, attention_mask=attention_mask)
356
+
357
+ hidden_states = self.norm_out(hidden_states)
358
+
359
+ if self.prd_embedding is not None:
360
+ hidden_states = hidden_states[:, -1]
361
+ else:
362
+ hidden_states = hidden_states[:, additional_embeddings_len:]
363
+
364
+ predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
365
+
366
+ if not return_dict:
367
+ return (predicted_image_embedding,)
368
+
369
+ return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
370
+
371
+ def post_process_latents(self, prior_latents):
372
+ prior_latents = (prior_latents * self.clip_std) + self.clip_mean
373
+ return prior_latents
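
The lambda prior above drops the diffusion timestep entirely and conditions on an extra control (subject) embedding token, predicting the image embedding in a single forward pass. A toy-scale, self-contained sketch follows (assumptions: a tiny made-up config and random inputs, not the released checkpoint's configuration; `additional_embeddings=4` is used here so the positional embedding covers the control, proj, hidden-state and prd tokens appended after the text tokens):

import torch
from src.priors.lambda_prior_transformer import PriorTransformer

prior = PriorTransformer(
    num_attention_heads=2,
    attention_head_dim=8,
    num_layers=1,
    embedding_dim=32,
    num_embeddings=4,
    additional_embeddings=4,
)

hidden_states = torch.randn(1, 32)              # noise / current image-embedding estimate
proj_embedding = torch.randn(1, 32)             # pooled text embedding
encoder_hidden_states = torch.randn(1, 4, 32)   # per-token text hidden states
text_mask = torch.ones(1, 4)                    # all text tokens valid
control_embedding = torch.zeros(1, 32)          # zero vector = no subject conditioning

out = prior(
    hidden_states,
    proj_embedding=proj_embedding,
    encoder_hidden_states=encoder_hidden_states,
    attention_mask=text_mask,
    control_embedding=control_embedding,
)
print(out.predicted_image_embedding.shape)      # torch.Size([1, 32])
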
src/priors/prior_transformer.py ADDED
@@ -0,0 +1,368 @@
1
+ import sys
2
+ sys.path.append("..")
3
+
4
+ from dataclasses import dataclass
5
+ from typing import Dict, Optional, Union
6
+
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import nn
11
+
12
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
13
+ from diffusers.utils import BaseOutput
14
+ from diffusers.models.attention import BasicTransformerBlock
15
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
16
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
17
+ from diffusers.models.modeling_utils import ModelMixin
18
+
19
+
20
+ @dataclass
21
+ class PriorTransformerOutput(BaseOutput):
22
+ """
23
+ The output of [`PriorTransformer`].
24
+
25
+ Args:
26
+ predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
27
+ The predicted CLIP image embedding conditioned on the CLIP text embedding input.
28
+ """
29
+
30
+ predicted_image_embedding: torch.FloatTensor
31
+
32
+
33
+ class PriorTransformer(ModelMixin, ConfigMixin):
34
+ """
35
+ A Prior Transformer model.
36
+
37
+ Parameters:
38
+ num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
39
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
40
+ num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
41
+ embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
42
+ num_embeddings (`int`, *optional*, defaults to 77):
43
+ The number of embeddings of the model input `hidden_states`
44
+ additional_embeddings (`int`, *optional*, defaults to 3): The number of additional tokens appended to the
45
+ projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
46
+ additional_embeddings`.
47
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
48
+ time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
49
+ The activation function to use to create timestep embeddings.
50
+ norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
51
+ passing to Transformer blocks. Set it to `None` if normalization is not needed.
52
+ embedding_proj_norm_type (`str`, *optional*, defaults to None):
53
+ The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
54
+ needed.
55
+ encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
56
+ The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
57
+ `encoder_hidden_states` is `None`.
58
+ added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
59
+ Choose from `prd` or `None`. If `prd` is chosen, it will prepend a token indicating the (quantized) dot
60
+ product between the text embedding and image embedding as proposed in the unclip paper
61
+ https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended.
62
+ time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings.
63
+ If None, will be set to `num_attention_heads * attention_head_dim`
64
+ embedding_proj_dim (`int`, *optional*, defaults to None):
65
+ The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
66
+ clip_embed_dim (`int`, *optional*, defaults to None):
67
+ The dimension of the output. If None, will be set to `embedding_dim`.
68
+ """
69
+
70
+ @register_to_config
71
+ def __init__(
72
+ self,
73
+ num_attention_heads: int = 32,
74
+ attention_head_dim: int = 64,
75
+ num_layers: int = 20,
76
+ embedding_dim: int = 768,
77
+ num_embeddings=77,
78
+ additional_embeddings=3, # as we have removed the time embedding
79
+ dropout: float = 0.0,
80
+ # time_embed_act_fn: str = "silu",
81
+ norm_in_type: Optional[str] = None, # layer
82
+ embedding_proj_norm_type: Optional[str] = None, # layer
83
+ encoder_hid_proj_type: Optional[str] = "linear", # linear
84
+ added_emb_type: Optional[str] = "prd", # prd
85
+ # time_embed_dim: Optional[int] = None,
86
+ embedding_proj_dim: Optional[int] = None,
87
+ clip_embed_dim: Optional[int] = None,
88
+ ):
89
+ super().__init__()
90
+ self.num_attention_heads = num_attention_heads
91
+ self.attention_head_dim = attention_head_dim
92
+ inner_dim = num_attention_heads * attention_head_dim
93
+ self.additional_embeddings = additional_embeddings
94
+
95
+ # time_embed_dim = time_embed_dim or inner_dim
96
+ embedding_proj_dim = embedding_proj_dim or embedding_dim
97
+ clip_embed_dim = clip_embed_dim or embedding_dim
98
+
99
+ # self.time_proj = Timesteps(inner_dim, True, 0)
100
+ # self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)
101
+
102
+ self.proj_in = nn.Linear(embedding_dim, inner_dim)
103
+
104
+ if embedding_proj_norm_type is None:
105
+ self.embedding_proj_norm = None
106
+ elif embedding_proj_norm_type == "layer":
107
+ self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
108
+ else:
109
+ raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")
110
+
111
+ self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)
112
+
113
+ if encoder_hid_proj_type is None:
114
+ self.encoder_hidden_states_proj = None
115
+ elif encoder_hid_proj_type == "linear":
116
+ self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
117
+ else:
118
+ raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")
119
+
120
+ self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
121
+
122
+ if added_emb_type == "prd":
123
+ self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
124
+ elif added_emb_type is None:
125
+ self.prd_embedding = None
126
+ else:
127
+ raise ValueError(
128
+ f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
129
+ )
130
+
131
+ self.transformer_blocks = nn.ModuleList(
132
+ [
133
+ BasicTransformerBlock(
134
+ inner_dim,
135
+ num_attention_heads,
136
+ attention_head_dim,
137
+ dropout=dropout,
138
+ activation_fn="gelu",
139
+ attention_bias=True,
140
+ )
141
+ for d in range(num_layers)
142
+ ]
143
+ )
144
+
145
+ if norm_in_type == "layer":
146
+ self.norm_in = nn.LayerNorm(inner_dim)
147
+ elif norm_in_type is None:
148
+ self.norm_in = None
149
+ else:
150
+ raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")
151
+
152
+ self.norm_out = nn.LayerNorm(inner_dim)
153
+
154
+ self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)
155
+
156
+ causal_attention_mask = torch.full(
157
+ [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
158
+ )
159
+ causal_attention_mask.triu_(1)
160
+ causal_attention_mask = causal_attention_mask[None, ...]
161
+ self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
162
+
163
+ self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
164
+ self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
165
+
166
+ @property
167
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
168
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
169
+ r"""
170
+ Returns:
171
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
172
+ indexed by their weight names.
173
+ """
174
+ # set recursively
175
+ processors = {}
176
+
177
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
178
+ if hasattr(module, "set_processor"):
179
+ processors[f"{name}.processor"] = module.processor
180
+
181
+ for sub_name, child in module.named_children():
182
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
183
+
184
+ return processors
185
+
186
+ for name, module in self.named_children():
187
+ fn_recursive_add_processors(name, module, processors)
188
+
189
+ return processors
190
+
191
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
192
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
193
+ r"""
194
+ Sets the attention processor to use to compute attention.
195
+
196
+ Parameters:
197
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
198
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
199
+ for **all** `Attention` layers.
200
+
201
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
202
+ processor. This is strongly recommended when setting trainable attention processors.
203
+
204
+ """
205
+ count = len(self.attn_processors.keys())
206
+
207
+ if isinstance(processor, dict) and len(processor) != count:
208
+ raise ValueError(
209
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
210
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
211
+ )
212
+
213
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
214
+ if hasattr(module, "set_processor"):
215
+ if not isinstance(processor, dict):
216
+ module.set_processor(processor)
217
+ else:
218
+ module.set_processor(processor.pop(f"{name}.processor"))
219
+
220
+ for sub_name, child in module.named_children():
221
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
222
+
223
+ for name, module in self.named_children():
224
+ fn_recursive_attn_processor(name, module, processor)
225
+
226
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
227
+ def set_default_attn_processor(self):
228
+ """
229
+ Disables custom attention processors and sets the default attention implementation.
230
+ """
231
+ self.set_attn_processor(AttnProcessor())
232
+
233
+ def forward(
234
+ self,
235
+ hidden_states,
236
+ # timestep: Union[torch.Tensor, float, int],
237
+ proj_embedding: torch.FloatTensor,
238
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
239
+ attention_mask: Optional[torch.BoolTensor] = None,
240
+ return_dict: bool = True,
241
+ ):
242
+ """
243
+ The [`PriorTransformer`] forward method.
244
+
245
+ Args:
246
+ hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
247
+ The currently predicted image embeddings.
250
+ proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
251
+ Projected embedding vector the denoising process is conditioned on.
252
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
253
+ Hidden states of the text embeddings the denoising process is conditioned on.
254
+ attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
255
+ Text mask for the text embeddings.
256
+ return_dict (`bool`, *optional*, defaults to `True`):
257
+ Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
258
+ tuple.
259
+
260
+ Returns:
261
+ [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
262
+ If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a
263
+ tuple is returned where the first element is the sample tensor.
264
+ """
265
+ batch_size = hidden_states.shape[0]
266
+
267
+ # timesteps = timestep
268
+ # if not torch.is_tensor(timesteps):
269
+ # timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
270
+ # elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
271
+ # timesteps = timesteps[None].to(hidden_states.device)
272
+
273
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
274
+ # timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)
275
+
276
+ # timesteps_projected = self.time_proj(timesteps)
277
+
278
+ # timesteps does not contain any weights and will always return f32 tensors
279
+ # but time_embedding might be fp16, so we need to cast here.
280
+ # timesteps_projected = timesteps_projected.to(dtype=self.dtype)
281
+ # time_embeddings = self.time_embedding(timesteps_projected)
282
+
283
+ if self.embedding_proj_norm is not None:
284
+ proj_embedding = self.embedding_proj_norm(proj_embedding)
285
+
286
+ proj_embeddings = self.embedding_proj(proj_embedding)
287
+ if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
288
+ encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
289
+ elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
290
+ raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")
291
+
292
+ hidden_states = self.proj_in(hidden_states)
293
+
294
+ positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
295
+
296
+ additional_embeds = []
297
+ additional_embeddings_len = 0
298
+
299
+ if encoder_hidden_states is not None:
300
+ additional_embeds.append(encoder_hidden_states)
301
+ additional_embeddings_len += encoder_hidden_states.shape[1]
302
+
303
+ if len(proj_embeddings.shape) == 2:
304
+ proj_embeddings = proj_embeddings[:, None, :]
305
+
306
+ if len(hidden_states.shape) == 2:
307
+ hidden_states = hidden_states[:, None, :]
308
+
309
+ additional_embeds = additional_embeds + [
310
+ proj_embeddings,
311
+ # time_embeddings[:, None, :],
312
+ hidden_states,
313
+ ]
314
+
315
+ if self.prd_embedding is not None:
316
+ prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
317
+ additional_embeds.append(prd_embedding)
318
+
319
+ hidden_states = torch.cat(
320
+ additional_embeds,
321
+ dim=1,
322
+ )
323
+
324
+ # Allow positional_embedding to not include the `additional_embeddings` and instead pad it with zeros for these additional tokens
325
+ additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
326
+ if positional_embeddings.shape[1] < hidden_states.shape[1]:
327
+ positional_embeddings = F.pad(
328
+ positional_embeddings,
329
+ (
330
+ 0,
331
+ 0,
332
+ additional_embeddings_len,
333
+ self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
334
+ ),
335
+ value=0.0,
336
+ )
337
+
338
+ hidden_states = hidden_states + positional_embeddings
339
+
340
+ if attention_mask is not None:
341
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
342
+ attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
343
+ attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
344
+ attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
345
+
346
+ if self.norm_in is not None:
347
+ hidden_states = self.norm_in(hidden_states)
348
+
349
+ for block in self.transformer_blocks:
350
+ hidden_states = block(hidden_states, attention_mask=attention_mask)
351
+
352
+ hidden_states = self.norm_out(hidden_states)
353
+
354
+ if self.prd_embedding is not None:
355
+ hidden_states = hidden_states[:, -1]
356
+ else:
357
+ hidden_states = hidden_states[:, additional_embeddings_len:]
358
+
359
+ predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
360
+
361
+ if not return_dict:
362
+ return (predicted_image_embedding,)
363
+
364
+ return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
365
+
366
+ def post_process_latents(self, prior_latents):
367
+ prior_latents = (prior_latents * self.clip_std) + self.clip_mean
368
+ return prior_latents
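
Both prior variants close with the same helper: `post_process_latents` maps the model's normalized prediction back to CLIP image-embedding space via the stored `clip_mean` / `clip_std` parameters. A tiny runnable sketch (toy config; in a freshly initialized model these statistics are still zeros and only become meaningful once checkpoint weights are loaded):

import torch
from src.priors.prior_transformer import PriorTransformer

prior = PriorTransformer(num_attention_heads=2, attention_head_dim=8,
                         num_layers=1, embedding_dim=32, num_embeddings=4)

predicted = torch.randn(1, 32)  # stand-in for a predicted_image_embedding
image_embeds = prior.post_process_latents(predicted)  # predicted * clip_std + clip_mean
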