Maitreya Patel committed
Commit c5e8b9c
1 Parent(s): 6435502

demo files

.gitignore ADDED
@@ -0,0 +1,160 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
main.py ADDED
@@ -0,0 +1,191 @@
1
+ import gradio as gr
2
+ from PIL import Image
3
+
4
+ import torch
5
+
6
+ from torchvision import transforms
7
+ from transformers import (
8
+ CLIPProcessor,
9
+ CLIPModel,
10
+ CLIPTokenizer,
11
+ CLIPTextModelWithProjection,
12
+ CLIPVisionModelWithProjection,
13
+ CLIPFeatureExtractor,
14
+ )
15
+
16
+ import math
17
+ from typing import List
18
+ from PIL import Image, ImageChops
19
+ import numpy as np
20
+ import torch
21
+
22
+ from diffusers import UnCLIPPipeline
23
+
24
+ # from diffusers.utils.torch_utils import randn_tensor
25
+
26
+ from transformers import CLIPTokenizer
27
+
28
+ from src.priors.prior_transformer import (
29
+ PriorTransformer,
30
+ ) # original huggingface prior transformer without time conditioning
31
+ from src.pipelines.pipeline_kandinsky_prior import KandinskyPriorPipeline
32
+
33
+ from diffusers import DiffusionPipeline
34
+
35
+
36
+ __DEVICE__ = "cpu"
37
+ if torch.cuda.is_available():
38
+ __DEVICE__ = "cuda"
39
+
40
+ class Ours:
41
+ def __init__(self, device):
42
+ text_encoder = (
43
+ CLIPTextModelWithProjection.from_pretrained(
44
+ "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
45
+ projection_dim=1280,
46
+ torch_dtype=torch.float16,
47
+ )
48
+ .eval()
49
+ .requires_grad_(False)
50
+ )
51
+
52
+ tokenizer = CLIPTokenizer.from_pretrained(
53
+ "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
54
+ )
55
+
56
+ prior = PriorTransformer.from_pretrained(
57
+ "ECLIPSE-Community/ECLIPSE_KandinskyV22_Prior",
58
+ torch_dtype=torch.float16,
59
+ )
60
+
61
+ self.pipe_prior = KandinskyPriorPipeline.from_pretrained(
62
+ "kandinsky-community/kandinsky-2-2-prior",
63
+ prior=prior,
64
+ text_encoder=text_encoder,
65
+ tokenizer=tokenizer,
66
+ torch_dtype=torch.float16,
67
+ ).to(device)
68
+
69
+ self.pipe = DiffusionPipeline.from_pretrained(
70
+ "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
71
+ ).to(device)
72
+
73
+ def inference(self, text, negative_text, steps, guidance_scale):
74
+ gen_images = []
75
+ for i in range(1):
76
+ image_emb, negative_image_emb = self.pipe_prior(
77
+ text, negative_prompt=negative_text
78
+ ).to_tuple()
79
+ image = self.pipe(
80
+ image_embeds=image_emb,
81
+ negative_image_embeds=negative_image_emb,
82
+ num_inference_steps=steps,
83
+ guidance_scale=guidance_scale,
84
+ ).images
85
+ gen_images.append(image[0])
86
+ return gen_images
87
+
88
+
89
+ selected_model = Ours(device=__DEVICE__)
90
+
91
+
92
+ def get_images(text, negative_text, steps, guidance_scale):
93
+ images = selected_model.inference(text, negative_text, steps, guidance_scale)
94
+ new_images = []
95
+ for img in images:
96
+ new_images.append(img)
97
+ return new_images[0]
98
+
99
+
100
+ with gr.Blocks() as demo:
101
+ gr.Markdown(
102
+ """<h1 style="text-align: center;"><b><i>ECLIPSE</i>: Revisiting the Text-to-Image Prior for Efficient Image Generation</b></h1>
103
+ <h1 style='text-align: center;'><a href='https://eclipse-t2i.vercel.app/'>Project Page</a> | <a href='https://eclipse-t2i.vercel.app/'>Paper</a> </h1>
104
+ """
105
+ )
106
+
107
+ with gr.Group():
108
+ with gr.Row():
109
+ with gr.Column():
110
+ text = gr.Textbox(
111
+ label="Enter your prompt",
112
+ show_label=False,
113
+ max_lines=1,
114
+ placeholder="Enter your prompt",
115
+ elem_id="prompt-text-input",
116
+ ).style(
117
+ border=(True, False, True, True),
118
+ rounded=(True, False, False, True),
119
+ container=False,
120
+ )
121
+
122
+ with gr.Row():
123
+ with gr.Column():
124
+ negative_text = gr.Textbox(
125
+ label="Enter your negative prompt",
126
+ show_label=False,
127
+ max_lines=1,
128
+ placeholder="Enter your negative prompt",
129
+ elem_id="prompt-text-input",
130
+ ).style(
131
+ border=(True, False, True, True),
132
+ rounded=(True, False, False, True),
133
+ container=False,
134
+ )
135
+
136
+ with gr.Row():
137
+ steps = gr.Slider(label="Steps", minimum=10, maximum=100, value=50, step=1)
138
+ guidance_scale = gr.Slider(
139
+ label="Guidance Scale", minimum=0, maximum=10, value=7.5, step=0.1
140
+ )
141
+
142
+ with gr.Row():
143
+ btn = gr.Button(value="Generate Image", full_width=False)
144
+
145
+ gallery = gr.Image(
146
+ height=512, width=512, label="Generated images", show_label=True, elem_id="gallery"
147
+ ).style(preview=False, columns=1)
148
+
149
+ btn.click(
150
+ get_images,
151
+ inputs=[
152
+ text,
153
+ negative_text,
154
+ steps,
155
+ guidance_scale,
156
+ ],
157
+ outputs=gallery,
158
+ )
159
+ text.submit(
160
+ get_images,
161
+ inputs=[
162
+ text,
163
+ negative_text,
164
+ steps,
165
+ guidance_scale,
166
+ ],
167
+ outputs=gallery,
168
+ )
169
+ negative_text.submit(
170
+ get_images,
171
+ inputs=[
172
+ text,
173
+ negative_text,
174
+ steps,
175
+ guidance_scale,
176
+ ],
177
+ outputs=gallery,
178
+ )
179
+
180
+ with gr.Accordion(label="Ethics & Privacy", open=False):
181
+ gr.HTML(
182
+ """<div class="acknowledgments">
183
+ <p><h4>Privacy</h4>
184
+ We do not collect any images or key data. This demo is designed solely for fun and to reduce the misuse of AI.
185
+ <p><h4>Biases and content acknowledgment</h4>
186
+ This model will have the same biases as the pre-trained CLIP model. </div>
187
+ """
188
+ )
189
+
190
+ if __name__ == "__main__":
191
+ demo.queue(max_size=20).launch()
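
For reference, below is a minimal sketch of the inference path that `main.py` wraps in the Gradio UI, using the same checkpoints and calls as the `Ours` class above. It assumes the `src/` package added in this commit is importable and that a CUDA device is available; to run the full demo instead, install the pinned requirements and launch with `python main.py`.

```py
# Sketch only: mirrors Ours.__init__ / Ours.inference from main.py above.
import torch
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import DiffusionPipeline

from src.priors.prior_transformer import PriorTransformer
from src.pipelines.pipeline_kandinsky_prior import KandinskyPriorPipeline

clip_repo = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
text_encoder = (
    CLIPTextModelWithProjection.from_pretrained(
        clip_repo, projection_dim=1280, torch_dtype=torch.float16
    )
    .eval()
    .requires_grad_(False)
)
tokenizer = CLIPTokenizer.from_pretrained(clip_repo)

# ECLIPSE prior (no time conditioning) plugged into the Kandinsky 2.2 prior pipeline.
prior = PriorTransformer.from_pretrained(
    "ECLIPSE-Community/ECLIPSE_KandinskyV22_Prior", torch_dtype=torch.float16
)
pipe_prior = KandinskyPriorPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-prior",
    prior=prior,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    torch_dtype=torch.float16,
).to("cuda")
pipe = DiffusionPipeline.from_pretrained(
    "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
).to("cuda")

# The prior predicts CLIP image embeddings; the decoder turns them into pixels.
image_emb, negative_image_emb = pipe_prior(
    "red cat, 4k photo", negative_prompt=""
).to_tuple()
image = pipe(
    image_embeds=image_emb,
    negative_image_embeds=negative_image_emb,
    num_inference_steps=50,
    guidance_scale=7.5,
).images[0]
image.save("eclipse_demo.png")
```
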
requirements.txt ADDED
@@ -0,0 +1,16 @@
1
+ accelerate==0.24.0
2
+ datasets==2.14.6
3
+ diffusers==0.20.2
4
+ numpy==1.26.1
5
+ packaging==23.2
6
+ pandas_stubs==1.2.0.57
7
+ Pillow==10.1.0
8
+ torch==2.0.0
9
+ torchvision==0.15.1
10
+ tqdm==4.66.1
11
+ transformers==4.34.1
12
+ gradio
13
+ jmespath
14
+ opencv-python
15
+ PyWavelets
16
+ gradio==3.47.1
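
Two notes on these pins: `gradio` appears both unpinned and pinned to 3.47.1 (pip resolves the two constraints jointly, so 3.47.1 is what gets installed), and the pipeline files below import `randn_tensor` from `diffusers.utils`, which matches the pinned `diffusers==0.20.2`; newer diffusers releases moved it to `diffusers.utils.torch_utils`, which is why `main.py` keeps that import commented out. A purely illustrative compatibility shim, not part of this commit:

```py
# Hypothetical shim (illustration only): pick up randn_tensor from whichever
# location the installed diffusers version provides.
try:
    from diffusers.utils import randn_tensor  # diffusers <= 0.20.x, as pinned here
except ImportError:
    from diffusers.utils.torch_utils import randn_tensor  # newer diffusers releases
```
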
src/pipelines/__init__.py ADDED
File without changes
src/pipelines/pipeline_kandinsky_prior.py ADDED
@@ -0,0 +1,591 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional, Union
3
+
4
+ import numpy as np
5
+ import PIL
6
+ import torch
7
+ from transformers import (
8
+ CLIPImageProcessor,
9
+ CLIPTextModelWithProjection,
10
+ CLIPTokenizer,
11
+ CLIPVisionModelWithProjection,
12
+ )
13
+
14
+ from diffusers.models import PriorTransformer
15
+ from diffusers.schedulers import UnCLIPScheduler
16
+ from diffusers.utils import (
17
+ BaseOutput,
18
+ is_accelerate_available,
19
+ is_accelerate_version,
20
+ logging,
21
+ randn_tensor,
22
+ replace_example_docstring,
23
+ )
24
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
25
+
26
+
27
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
+
29
+ EXAMPLE_DOC_STRING = """
30
+ Examples:
31
+ ```py
32
+ >>> from diffusers import KandinskyPipeline, KandinskyPriorPipeline
33
+ >>> import torch
34
+
35
+ >>> pipe_prior = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior")
36
+ >>> pipe_prior.to("cuda")
37
+
38
+ >>> prompt = "red cat, 4k photo"
39
+ >>> out = pipe_prior(prompt)
40
+ >>> image_emb = out.image_embeds
41
+ >>> negative_image_emb = out.negative_image_embeds
42
+
43
+ >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1")
44
+ >>> pipe.to("cuda")
45
+
46
+ >>> image = pipe(
47
+ ... prompt,
48
+ ... image_embeds=image_emb,
49
+ ... negative_image_embeds=negative_image_emb,
50
+ ... height=768,
51
+ ... width=768,
52
+ ... num_inference_steps=100,
53
+ ... ).images
54
+
55
+ >>> image[0].save("cat.png")
56
+ ```
57
+ """
58
+
59
+ EXAMPLE_INTERPOLATE_DOC_STRING = """
60
+ Examples:
61
+ ```py
62
+ >>> from diffusers import KandinskyPriorPipeline, KandinskyPipeline
63
+ >>> from diffusers.utils import load_image
64
+ >>> import PIL
65
+
66
+ >>> import torch
67
+ >>> from torchvision import transforms
68
+
69
+ >>> pipe_prior = KandinskyPriorPipeline.from_pretrained(
70
+ ... "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
71
+ ... )
72
+ >>> pipe_prior.to("cuda")
73
+
74
+ >>> img1 = load_image(
75
+ ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
76
+ ... "/kandinsky/cat.png"
77
+ ... )
78
+
79
+ >>> img2 = load_image(
80
+ ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
81
+ ... "/kandinsky/starry_night.jpeg"
82
+ ... )
83
+
84
+ >>> images_texts = ["a cat", img1, img2]
85
+ >>> weights = [0.3, 0.3, 0.4]
86
+ >>> image_emb, zero_image_emb = pipe_prior.interpolate(images_texts, weights)
87
+
88
+ >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
89
+ >>> pipe.to("cuda")
90
+
91
+ >>> image = pipe(
92
+ ... "",
93
+ ... image_embeds=image_emb,
94
+ ... negative_image_embeds=zero_image_emb,
95
+ ... height=768,
96
+ ... width=768,
97
+ ... num_inference_steps=150,
98
+ ... ).images[0]
99
+
100
+ >>> image.save("starry_cat.png")
101
+ ```
102
+ """
103
+
104
+
105
+ @dataclass
106
+ class KandinskyPriorPipelineOutput(BaseOutput):
107
+ """
108
+ Output class for KandinskyPriorPipeline.
109
+
110
+ Args:
111
+ image_embeds (`torch.FloatTensor`)
112
+ clip image embeddings for text prompt
113
+ negative_image_embeds (`torch.FloatTensor` or `np.ndarray`)
114
+ clip image embeddings for unconditional tokens
115
+ """
116
+
117
+ image_embeds: Union[torch.FloatTensor, np.ndarray]
118
+ negative_image_embeds: Union[torch.FloatTensor, np.ndarray]
119
+
120
+
121
+ class KandinskyPriorPipeline(DiffusionPipeline):
122
+ """
123
+ Pipeline for generating image prior for Kandinsky
124
+
125
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
126
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
127
+
128
+ Args:
129
+ prior ([`PriorTransformer`]):
130
+ The canonical unCLIP prior to approximate the image embedding from the text embedding.
131
+ image_encoder ([`CLIPVisionModelWithProjection`]):
132
+ Frozen image-encoder.
133
+ text_encoder ([`CLIPTextModelWithProjection`]):
134
+ Frozen text-encoder.
135
+ tokenizer (`CLIPTokenizer`):
136
+ Tokenizer of class
137
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
138
+ scheduler ([`UnCLIPScheduler`]):
139
+ A scheduler to be used in combination with `prior` to generate image embedding.
140
+ """
141
+
142
+ _exclude_from_cpu_offload = ["prior"]
143
+
144
+ def __init__(
145
+ self,
146
+ prior: PriorTransformer,
147
+ image_encoder: CLIPVisionModelWithProjection,
148
+ text_encoder: CLIPTextModelWithProjection,
149
+ tokenizer: CLIPTokenizer,
150
+ scheduler: UnCLIPScheduler,
151
+ image_processor: CLIPImageProcessor,
152
+ ):
153
+ super().__init__()
154
+
155
+ self.register_modules(
156
+ prior=prior,
157
+ text_encoder=text_encoder,
158
+ tokenizer=tokenizer,
159
+ scheduler=scheduler,
160
+ image_encoder=image_encoder,
161
+ image_processor=image_processor,
162
+ )
163
+
164
+ @torch.no_grad()
165
+ @replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING)
166
+ def interpolate(
167
+ self,
168
+ images_and_prompts: List[Union[str, PIL.Image.Image, torch.FloatTensor]],
169
+ weights: List[float],
170
+ num_images_per_prompt: int = 1,
171
+ num_inference_steps: int = 25,
172
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
173
+ latents: Optional[torch.FloatTensor] = None,
174
+ negative_prior_prompt: Optional[str] = None,
175
+ negative_prompt: str = "",
176
+ guidance_scale: float = 4.0,
177
+ device=None,
178
+ ):
179
+ """
180
+ Function invoked when using the prior pipeline for interpolation.
181
+
182
+ Args:
183
+ images_and_prompts (`List[Union[str, PIL.Image.Image, torch.FloatTensor]]`):
184
+ list of prompts and images to guide the image generation.
185
+ weights: (`List[float]`):
186
+ list of weights for each condition in `images_and_prompts`
187
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
188
+ The number of images to generate per prompt.
189
+ num_inference_steps (`int`, *optional*, defaults to 25):
190
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
191
+ expense of slower inference.
192
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
193
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
194
+ to make generation deterministic.
195
+ latents (`torch.FloatTensor`, *optional*):
196
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
197
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
198
+ tensor will be generated by sampling using the supplied random `generator`.
199
+ negative_prior_prompt (`str`, *optional*):
200
+ The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if
201
+ `guidance_scale` is less than `1`).
202
+ negative_prompt (`str` or `List[str]`, *optional*):
203
+ The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if
204
+ `guidance_scale` is less than `1`).
205
+ guidance_scale (`float`, *optional*, defaults to 4.0):
206
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
207
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
208
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
209
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
210
+ usually at the expense of lower image quality.
211
+
212
+ Examples:
213
+
214
+ Returns:
215
+ [`KandinskyPriorPipelineOutput`] or `tuple`
216
+ """
217
+
218
+ device = device or self.device
219
+
220
+ if len(images_and_prompts) != len(weights):
221
+ raise ValueError(
222
+ f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of the same length"
223
+ )
224
+
225
+ image_embeddings = []
226
+ for cond, weight in zip(images_and_prompts, weights):
227
+ if isinstance(cond, str):
228
+ image_emb = self(
229
+ cond,
230
+ num_inference_steps=num_inference_steps,
231
+ num_images_per_prompt=num_images_per_prompt,
232
+ generator=generator,
233
+ latents=latents,
234
+ negative_prompt=negative_prior_prompt,
235
+ guidance_scale=guidance_scale,
236
+ ).image_embeds
237
+
238
+ elif isinstance(cond, (PIL.Image.Image, torch.Tensor)):
239
+ if isinstance(cond, PIL.Image.Image):
240
+ cond = (
241
+ self.image_processor(cond, return_tensors="pt")
242
+ .pixel_values[0]
243
+ .unsqueeze(0)
244
+ .to(dtype=self.image_encoder.dtype, device=device)
245
+ )
246
+
247
+ image_emb = self.image_encoder(cond)["image_embeds"]
248
+
249
+ else:
250
+ raise ValueError(
251
+ f"`images_and_prompts` can only contain elements of type `str`, `PIL.Image.Image` or `torch.Tensor`, but got {type(cond)}"
252
+ )
253
+
254
+ image_embeddings.append(image_emb * weight)
255
+
256
+ image_emb = torch.cat(image_embeddings).sum(dim=0, keepdim=True)
257
+
258
+ out_zero = self(
259
+ negative_prompt,
260
+ num_inference_steps=num_inference_steps,
261
+ num_images_per_prompt=num_images_per_prompt,
262
+ generator=generator,
263
+ latents=latents,
264
+ negative_prompt=negative_prior_prompt,
265
+ guidance_scale=guidance_scale,
266
+ )
267
+ zero_image_emb = (
268
+ out_zero.negative_image_embeds
269
+ if negative_prompt == ""
270
+ else out_zero.image_embeds
271
+ )
272
+
273
+ return KandinskyPriorPipelineOutput(
274
+ image_embeds=image_emb, negative_image_embeds=zero_image_emb
275
+ )
276
+
277
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
278
+ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
279
+ if latents is None:
280
+ latents = randn_tensor(
281
+ shape, generator=generator, device=device, dtype=dtype
282
+ )
283
+ else:
284
+ if latents.shape != shape:
285
+ raise ValueError(
286
+ f"Unexpected latents shape, got {latents.shape}, expected {shape}"
287
+ )
288
+ latents = latents.to(device)
289
+
290
+ latents = latents * scheduler.init_noise_sigma
291
+ return latents
292
+
293
+ def get_zero_embed(self, batch_size=1, device=None):
294
+ device = device or self.device
295
+ zero_img = torch.zeros(
296
+ 1,
297
+ 3,
298
+ self.image_encoder.config.image_size,
299
+ self.image_encoder.config.image_size,
300
+ ).to(device=device, dtype=self.image_encoder.dtype)
301
+ zero_image_emb = self.image_encoder(zero_img)["image_embeds"]
302
+ zero_image_emb = zero_image_emb.repeat(batch_size, 1)
303
+ return zero_image_emb
304
+
305
+ def _encode_prompt(
306
+ self,
307
+ prompt,
308
+ device,
309
+ num_images_per_prompt,
310
+ do_classifier_free_guidance,
311
+ negative_prompt=None,
312
+ ):
313
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
314
+ # get prompt text embeddings
315
+ text_inputs = self.tokenizer(
316
+ prompt,
317
+ padding="max_length",
318
+ max_length=self.tokenizer.model_max_length,
319
+ truncation=True,
320
+ return_tensors="pt",
321
+ )
322
+ text_input_ids = text_inputs.input_ids
323
+ text_mask = text_inputs.attention_mask.bool().to(device)
324
+
325
+ untruncated_ids = self.tokenizer(
326
+ prompt, padding="longest", return_tensors="pt"
327
+ ).input_ids
328
+
329
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
330
+ text_input_ids, untruncated_ids
331
+ ):
332
+ removed_text = self.tokenizer.batch_decode(
333
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
334
+ )
335
+ logger.warning(
336
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
337
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
338
+ )
339
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
340
+
341
+ text_encoder_output = self.text_encoder(text_input_ids.to(device))
342
+
343
+ prompt_embeds = text_encoder_output.text_embeds
344
+ text_encoder_hidden_states = text_encoder_output.last_hidden_state
345
+
346
+ prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
347
+ text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(
348
+ num_images_per_prompt, dim=0
349
+ )
350
+ text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
351
+
352
+ if do_classifier_free_guidance:
353
+ uncond_tokens: List[str]
354
+ if negative_prompt is None:
355
+ uncond_tokens = [""] * batch_size
356
+ elif type(prompt) is not type(negative_prompt):
357
+ raise TypeError(
358
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
359
+ f" {type(prompt)}."
360
+ )
361
+ elif isinstance(negative_prompt, str):
362
+ uncond_tokens = [negative_prompt]
363
+ elif batch_size != len(negative_prompt):
364
+ raise ValueError(
365
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
366
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
367
+ " the batch size of `prompt`."
368
+ )
369
+ else:
370
+ uncond_tokens = negative_prompt
371
+
372
+ uncond_input = self.tokenizer(
373
+ uncond_tokens,
374
+ padding="max_length",
375
+ max_length=self.tokenizer.model_max_length,
376
+ truncation=True,
377
+ return_tensors="pt",
378
+ )
379
+ uncond_text_mask = uncond_input.attention_mask.bool().to(device)
380
+ negative_prompt_embeds_text_encoder_output = self.text_encoder(
381
+ uncond_input.input_ids.to(device)
382
+ )
383
+
384
+ negative_prompt_embeds = (
385
+ negative_prompt_embeds_text_encoder_output.text_embeds
386
+ )
387
+ uncond_text_encoder_hidden_states = (
388
+ negative_prompt_embeds_text_encoder_output.last_hidden_state
389
+ )
390
+
391
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
392
+
393
+ seq_len = negative_prompt_embeds.shape[1]
394
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
395
+ 1, num_images_per_prompt
396
+ )
397
+ negative_prompt_embeds = negative_prompt_embeds.view(
398
+ batch_size * num_images_per_prompt, seq_len
399
+ )
400
+
401
+ seq_len = uncond_text_encoder_hidden_states.shape[1]
402
+ uncond_text_encoder_hidden_states = (
403
+ uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
404
+ )
405
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
406
+ batch_size * num_images_per_prompt, seq_len, -1
407
+ )
408
+ uncond_text_mask = uncond_text_mask.repeat_interleave(
409
+ num_images_per_prompt, dim=0
410
+ )
411
+
412
+ # done duplicates
413
+
414
+ # For classifier free guidance, we need to do two forward passes.
415
+ # Here we concatenate the unconditional and text embeddings into a single batch
416
+ # to avoid doing two forward passes
417
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
418
+ text_encoder_hidden_states = torch.cat(
419
+ [uncond_text_encoder_hidden_states, text_encoder_hidden_states]
420
+ )
421
+
422
+ text_mask = torch.cat([uncond_text_mask, text_mask])
423
+
424
+ return prompt_embeds, text_encoder_hidden_states, text_mask
425
+
426
+ def enable_model_cpu_offload(self, gpu_id=0):
427
+ r"""
428
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
429
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
430
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
431
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
432
+ """
433
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
434
+ from accelerate import cpu_offload_with_hook
435
+ else:
436
+ raise ImportError(
437
+ "`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher."
438
+ )
439
+
440
+ device = torch.device(f"cuda:{gpu_id}")
441
+
442
+ if self.device.type != "cpu":
443
+ self.to("cpu", silence_dtype_warnings=True)
444
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
445
+
446
+ hook = None
447
+ for cpu_offloaded_model in [self.text_encoder, self.prior]:
448
+ _, hook = cpu_offload_with_hook(
449
+ cpu_offloaded_model, device, prev_module_hook=hook
450
+ )
451
+
452
+ # We'll offload the last model manually.
453
+ self.prior_hook = hook
454
+
455
+ _, hook = cpu_offload_with_hook(
456
+ self.image_encoder, device, prev_module_hook=self.prior_hook
457
+ )
458
+
459
+ self.final_offload_hook = hook
460
+
461
+ @torch.no_grad()
462
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
463
+ def __call__(
464
+ self,
465
+ prompt: Union[str, List[str]],
466
+ negative_prompt: Optional[Union[str, List[str]]] = None,
467
+ num_images_per_prompt: int = 1,
468
+ num_inference_steps: int = 25,
469
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
470
+ latents: Optional[torch.FloatTensor] = None,
471
+ guidance_scale: float = 4.0,
472
+ output_type: Optional[str] = "pt",
473
+ return_dict: bool = True,
474
+ ):
475
+ """
476
+ Function invoked when calling the pipeline for generation.
477
+
478
+ Args:
479
+ prompt (`str` or `List[str]`):
480
+ The prompt or prompts to guide the image generation.
481
+ negative_prompt (`str` or `List[str]`, *optional*):
482
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
483
+ if `guidance_scale` is less than `1`).
484
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
485
+ The number of images to generate per prompt.
486
+ num_inference_steps (`int`, *optional*, defaults to 25):
487
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
488
+ expense of slower inference.
489
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
490
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
491
+ to make generation deterministic.
492
+ latents (`torch.FloatTensor`, *optional*):
493
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
494
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
495
+ tensor will be generated by sampling using the supplied random `generator`.
496
+ guidance_scale (`float`, *optional*, defaults to 4.0):
497
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
498
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
499
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
500
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
501
+ usually at the expense of lower image quality.
502
+ output_type (`str`, *optional*, defaults to `"pt"`):
503
+ The output format of the generated image. Choose between: `"np"` (`np.array`) or `"pt"`
504
+ (`torch.Tensor`).
505
+ return_dict (`bool`, *optional*, defaults to `True`):
506
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
507
+
508
+ Examples:
509
+
510
+ Returns:
511
+ [`KandinskyPriorPipelineOutput`] or `tuple`
512
+ """
513
+
514
+ if isinstance(prompt, str):
515
+ prompt = [prompt]
516
+ elif not isinstance(prompt, list):
517
+ raise ValueError(
518
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
519
+ )
520
+
521
+ if isinstance(negative_prompt, str):
522
+ negative_prompt = [negative_prompt]
523
+ elif not isinstance(negative_prompt, list) and negative_prompt is not None:
524
+ raise ValueError(
525
+ f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}"
526
+ )
527
+
528
+ # if the negative prompt is defined we double the batch size to
529
+ # directly retrieve the negative prompt embedding
530
+ if negative_prompt is not None:
531
+ prompt = prompt + negative_prompt
532
+ negative_prompt = 2 * negative_prompt
533
+
534
+ device = self._execution_device
535
+
536
+ batch_size = len(prompt)
537
+ batch_size = batch_size * num_images_per_prompt
538
+
539
+ prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
540
+ prompt, device, num_images_per_prompt, False, negative_prompt
541
+ )
542
+
543
+ hidden_states = randn_tensor(
544
+ (batch_size, prompt_embeds.shape[-1]),
545
+ device=prompt_embeds.device,
546
+ dtype=prompt_embeds.dtype,
547
+ generator=generator,
548
+ )
549
+
550
+ latents = self.prior(
551
+ hidden_states,
552
+ proj_embedding=prompt_embeds,
553
+ encoder_hidden_states=text_encoder_hidden_states,
554
+ attention_mask=text_mask,
555
+ ).predicted_image_embedding
556
+
557
+ image_embeddings = latents
558
+
559
+ # if a negative prompt has been defined, we split the predicted image embeddings in two to retrieve the negative embedding
560
+ if negative_prompt is None:
561
+ zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
562
+
563
+ if (
564
+ hasattr(self, "final_offload_hook")
565
+ and self.final_offload_hook is not None
566
+ ):
567
+ self.final_offload_hook.offload()
568
+ else:
569
+ image_embeddings, zero_embeds = image_embeddings.chunk(2)
570
+
571
+ if (
572
+ hasattr(self, "final_offload_hook")
573
+ and self.final_offload_hook is not None
574
+ ):
575
+ self.prior_hook.offload()
576
+
577
+ if output_type not in ["pt", "np"]:
578
+ raise ValueError(
579
+ f"Only the output types `pt` and `np` are supported, not output_type={output_type}"
580
+ )
581
+
582
+ if output_type == "np":
583
+ image_embeddings = image_embeddings.cpu().numpy()
584
+ zero_embeds = zero_embeds.cpu().numpy()
585
+
586
+ if not return_dict:
587
+ return (image_embeddings, zero_embeds)
588
+
589
+ return KandinskyPriorPipelineOutput(
590
+ image_embeds=image_embeddings, negative_image_embeds=zero_embeds
591
+ )
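
A minimal sketch of how this modified prior pipeline is driven from the demo (`pipe_prior` here is the `KandinskyPriorPipeline` assembled as in `main.py` above). Unlike the stock Kandinsky prior, `__call__` runs the ECLIPSE `PriorTransformer` in a single forward pass with no denoising loop; when a `negative_prompt` is given, the batch is doubled and the prediction is chunked into conditional and negative image embeddings, otherwise the negative slot falls back to `get_zero_embed`.

```py
# Illustration only; pipe_prior is built exactly as in main.py above.
out = pipe_prior(
    "a watercolor painting of a fox",
    negative_prompt="low quality, blurry",  # doubles the batch internally, then chunks
    num_images_per_prompt=1,
    guidance_scale=4.0,
)
image_emb = out.image_embeds                    # predicted CLIP image embedding
negative_image_emb = out.negative_image_embeds  # embedding for the negative prompt

# Without a negative prompt, the negative embedding is the zero-image embedding.
image_emb_only, zero_emb = pipe_prior("a watercolor painting of a fox").to_tuple()
```
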
src/pipelines/pipeline_unclip.py ADDED
@@ -0,0 +1,529 @@
1
+ import sys
2
+
3
+ sys.path.append("..")
4
+
5
+ import inspect
6
+ from typing import List, Optional, Tuple, Union
7
+
8
+ import torch
9
+ from torch.nn import functional as F
10
+ from transformers import CLIPTextModelWithProjection, CLIPTokenizer
11
+ from transformers.models.clip.modeling_clip import CLIPTextModelOutput
12
+
13
+ from diffusers.models import UNet2DConditionModel, UNet2DModel
14
+ from diffusers.schedulers import UnCLIPScheduler
15
+ from diffusers.utils import logging, randn_tensor
16
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
17
+ from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
18
+
19
+
20
+ from diffusers.models import PriorTransformer
21
+
22
+
23
+ import torch
24
+ from torchvision.transforms import ToPILImage
25
+
26
+ import copy
27
+
28
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29
+
30
+
31
+ class UnCLIPPipeline(DiffusionPipeline):
32
+ """
33
+ Pipeline for text-to-image generation using unCLIP.
34
+
35
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
36
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
37
+
38
+ Args:
39
+ text_encoder ([`~transformers.CLIPTextModelWithProjection`]):
40
+ Frozen text-encoder.
41
+ tokenizer ([`~transformers.CLIPTokenizer`]):
42
+ A `CLIPTokenizer` to tokenize text.
43
+ prior ([`PriorTransformer`]):
44
+ The canonical unCLIP prior to approximate the image embedding from the text embedding.
45
+ text_proj ([`UnCLIPTextProjModel`]):
46
+ Utility class to prepare and combine the embeddings before they are passed to the decoder.
47
+ decoder ([`UNet2DConditionModel`]):
48
+ The decoder to invert the image embedding into an image.
49
+ super_res_first ([`UNet2DModel`]):
50
+ Super resolution UNet. Used in all but the last step of the super resolution diffusion process.
51
+ super_res_last ([`UNet2DModel`]):
52
+ Super resolution UNet. Used in the last step of the super resolution diffusion process.
53
+ prior_scheduler ([`UnCLIPScheduler`]):
54
+ Scheduler used in the prior denoising process (a modified [`DDPMScheduler`]).
55
+ decoder_scheduler ([`UnCLIPScheduler`]):
56
+ Scheduler used in the decoder denoising process (a modified [`DDPMScheduler`]).
57
+ super_res_scheduler ([`UnCLIPScheduler`]):
58
+ Scheduler used in the super resolution denoising process (a modified [`DDPMScheduler`]).
59
+
60
+ """
61
+
62
+ _exclude_from_cpu_offload = ["prior"]
63
+
64
+ prior: PriorTransformer
65
+ decoder: UNet2DConditionModel
66
+ text_proj: UnCLIPTextProjModel
67
+ text_encoder: CLIPTextModelWithProjection
68
+ tokenizer: CLIPTokenizer
69
+ super_res_first: UNet2DModel
70
+ super_res_last: UNet2DModel
71
+
72
+ prior_scheduler: UnCLIPScheduler
73
+ decoder_scheduler: UnCLIPScheduler
74
+ super_res_scheduler: UnCLIPScheduler
75
+
76
+ def __init__(
77
+ self,
78
+ prior: PriorTransformer,
79
+ decoder: UNet2DConditionModel,
80
+ text_encoder: CLIPTextModelWithProjection,
81
+ tokenizer: CLIPTokenizer,
82
+ text_proj: UnCLIPTextProjModel,
83
+ super_res_first: UNet2DModel,
84
+ super_res_last: UNet2DModel,
85
+ prior_scheduler: UnCLIPScheduler,
86
+ decoder_scheduler: UnCLIPScheduler,
87
+ super_res_scheduler: UnCLIPScheduler,
88
+ ):
89
+ super().__init__()
90
+
91
+ self.register_modules(
92
+ prior=prior,
93
+ decoder=decoder,
94
+ text_encoder=text_encoder,
95
+ tokenizer=tokenizer,
96
+ text_proj=text_proj,
97
+ super_res_first=super_res_first,
98
+ super_res_last=super_res_last,
99
+ prior_scheduler=prior_scheduler,
100
+ decoder_scheduler=decoder_scheduler,
101
+ super_res_scheduler=super_res_scheduler,
102
+ )
103
+
104
+ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
105
+ if latents is None:
106
+ latents = randn_tensor(
107
+ shape, generator=generator, device=device, dtype=dtype
108
+ )
109
+ else:
110
+ if latents.shape != shape:
111
+ raise ValueError(
112
+ f"Unexpected latents shape, got {latents.shape}, expected {shape}"
113
+ )
114
+ latents = latents.to(device)
115
+
116
+ latents = latents * scheduler.init_noise_sigma
117
+ return latents
118
+
119
+ def _encode_prompt(
120
+ self,
121
+ prompt,
122
+ device,
123
+ num_images_per_prompt,
124
+ do_classifier_free_guidance,
125
+ text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
126
+ text_attention_mask: Optional[torch.Tensor] = None,
127
+ ):
128
+ if text_model_output is None:
129
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
130
+ # get prompt text embeddings
131
+ text_inputs = self.tokenizer(
132
+ prompt,
133
+ padding="max_length",
134
+ max_length=self.tokenizer.model_max_length,
135
+ truncation=True,
136
+ return_tensors="pt",
137
+ )
138
+ text_input_ids = text_inputs.input_ids
139
+ text_mask = text_inputs.attention_mask.bool().to(device)
140
+
141
+ untruncated_ids = self.tokenizer(
142
+ prompt, padding="longest", return_tensors="pt"
143
+ ).input_ids
144
+
145
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[
146
+ -1
147
+ ] and not torch.equal(text_input_ids, untruncated_ids):
148
+ removed_text = self.tokenizer.batch_decode(
149
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
150
+ )
151
+ logger.warning(
152
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
153
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
154
+ )
155
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
156
+
157
+ text_encoder_output = self.text_encoder(text_input_ids.to(device))
158
+
159
+ prompt_embeds = text_encoder_output.text_embeds
160
+ text_encoder_hidden_states = text_encoder_output.last_hidden_state
161
+
162
+ else:
163
+ batch_size = text_model_output[0].shape[0]
164
+ prompt_embeds, text_encoder_hidden_states = (
165
+ text_model_output[0],
166
+ text_model_output[1],
167
+ )
168
+ text_mask = text_attention_mask
169
+
170
+ prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
171
+ text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(
172
+ num_images_per_prompt, dim=0
173
+ )
174
+ text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
175
+
176
+ if do_classifier_free_guidance:
177
+ uncond_tokens = [""] * batch_size
178
+
179
+ uncond_input = self.tokenizer(
180
+ uncond_tokens,
181
+ padding="max_length",
182
+ max_length=self.tokenizer.model_max_length,
183
+ truncation=True,
184
+ return_tensors="pt",
185
+ )
186
+ uncond_text_mask = uncond_input.attention_mask.bool().to(device)
187
+ negative_prompt_embeds_text_encoder_output = self.text_encoder(
188
+ uncond_input.input_ids.to(device)
189
+ )
190
+
191
+ negative_prompt_embeds = (
192
+ negative_prompt_embeds_text_encoder_output.text_embeds
193
+ )
194
+ uncond_text_encoder_hidden_states = (
195
+ negative_prompt_embeds_text_encoder_output.last_hidden_state
196
+ )
197
+
198
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
199
+
200
+ seq_len = negative_prompt_embeds.shape[1]
201
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
202
+ 1, num_images_per_prompt
203
+ )
204
+ negative_prompt_embeds = negative_prompt_embeds.view(
205
+ batch_size * num_images_per_prompt, seq_len
206
+ )
207
+
208
+ seq_len = uncond_text_encoder_hidden_states.shape[1]
209
+ uncond_text_encoder_hidden_states = (
210
+ uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
211
+ )
212
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
213
+ batch_size * num_images_per_prompt, seq_len, -1
214
+ )
215
+ uncond_text_mask = uncond_text_mask.repeat_interleave(
216
+ num_images_per_prompt, dim=0
217
+ )
218
+
219
+ # done duplicates
220
+
221
+ # For classifier free guidance, we need to do two forward passes.
222
+ # Here we concatenate the unconditional and text embeddings into a single batch
223
+ # to avoid doing two forward passes
224
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
225
+ text_encoder_hidden_states = torch.cat(
226
+ [uncond_text_encoder_hidden_states, text_encoder_hidden_states]
227
+ )
228
+
229
+ text_mask = torch.cat([uncond_text_mask, text_mask])
230
+
231
+ return prompt_embeds, text_encoder_hidden_states, text_mask
232
+
233
+ @torch.no_grad()
234
+ def __call__(
235
+ self,
236
+ prompt: Optional[Union[str, List[str]]] = None,
237
+ num_images_per_prompt: int = 1,
238
+ prior_num_inference_steps: int = 25,
239
+ decoder_num_inference_steps: int = 25,
240
+ super_res_num_inference_steps: int = 7,
241
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
242
+ prior_latents: Optional[torch.FloatTensor] = None,
243
+ decoder_latents: Optional[torch.FloatTensor] = None,
244
+ super_res_latents: Optional[torch.FloatTensor] = None,
245
+ text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
246
+ text_attention_mask: Optional[torch.Tensor] = None,
247
+ prior_guidance_scale: float = 4.0,
248
+ decoder_guidance_scale: float = 8.0,
249
+ output_type: Optional[str] = "pil",
250
+ return_dict: bool = True,
251
+ null_prompt_decoder: bool = False,
252
+ ):
253
+ """
254
+ The call function to the pipeline for generation.
255
+
256
+ Args:
257
+ prompt (`str` or `List[str]`):
258
+ The prompt or prompts to guide image generation. This can only be left undefined if `text_model_output`
259
+ and `text_attention_mask` are passed.
260
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
261
+ The number of images to generate per prompt.
262
+ prior_num_inference_steps (`int`, *optional*, defaults to 25):
263
+ The number of denoising steps for the prior. More denoising steps usually lead to a higher quality
264
+ image at the expense of slower inference.
265
+ decoder_num_inference_steps (`int`, *optional*, defaults to 25):
266
+ The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality
267
+ image at the expense of slower inference.
268
+ super_res_num_inference_steps (`int`, *optional*, defaults to 7):
269
+ The number of denoising steps for super resolution. More denoising steps usually lead to a higher
270
+ quality image at the expense of slower inference.
271
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
272
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
273
+ generation deterministic.
274
+ prior_latents (`torch.FloatTensor` of shape (batch size, embeddings dimension), *optional*):
275
+ Pre-generated noisy latents to be used as inputs for the prior.
276
+ decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*):
277
+ Pre-generated noisy latents to be used as inputs for the decoder.
278
+ super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*):
279
+ Pre-generated noisy latents to be used as inputs for the super resolution.
280
+ prior_guidance_scale (`float`, *optional*, defaults to 4.0):
281
+ A higher guidance scale value encourages the model to generate images closely linked to the text
282
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
283
+ decoder_guidance_scale (`float`, *optional*, defaults to 8.0):
284
+ A higher guidance scale value encourages the model to generate images closely linked to the text
285
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
286
+ text_model_output (`CLIPTextModelOutput`, *optional*):
287
+ Pre-defined [`CLIPTextModel`] outputs that can be derived from the text encoder. Pre-defined text
288
+ outputs can be passed for tasks like text embedding interpolations. Make sure to also pass
289
+ `text_attention_mask` in this case. `prompt` can then be left `None`.
290
+ text_attention_mask (`torch.Tensor`, *optional*):
291
+ Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention
292
+ masks are necessary when passing `text_model_output`.
293
+ output_type (`str`, *optional*, defaults to `"pil"`):
294
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
295
+ return_dict (`bool`, *optional*, defaults to `True`):
296
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
297
+
298
+ Returns:
299
+ [`~pipelines.ImagePipelineOutput`] or `tuple`:
300
+ If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
301
+ returned where the first element is a list with the generated images.
302
+ """
303
+ if prompt is not None:
304
+ if isinstance(prompt, str):
305
+ batch_size = 1
306
+ elif isinstance(prompt, list):
307
+ batch_size = len(prompt)
308
+ else:
309
+ raise ValueError(
310
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
311
+ )
312
+ else:
313
+ batch_size = text_model_output[0].shape[0]
314
+
315
+ device = self._execution_device
316
+
317
+ batch_size = batch_size * num_images_per_prompt
318
+
319
+ prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
320
+ prompt,
321
+ device,
322
+ num_images_per_prompt,
323
+ False,
324
+ text_model_output,
325
+ text_attention_mask,
326
+ )
327
+
328
+ hidden_states = randn_tensor(
329
+ (batch_size, prompt_embeds.shape[-1]),
330
+ device=prompt_embeds.device,
331
+ dtype=prompt_embeds.dtype,
332
+ generator=generator,
333
+ )
334
+
335
+ prior_latents = self.prior(
336
+ hidden_states,
337
+ proj_embedding=prompt_embeds,
338
+ encoder_hidden_states=text_encoder_hidden_states,
339
+ attention_mask=text_mask,
340
+ ).predicted_image_embedding
341
+
342
+ do_classifier_free_guidance = decoder_guidance_scale > 1.0
343
+ prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
344
+ prompt if not null_prompt_decoder else "",
345
+ device,
346
+ num_images_per_prompt,
347
+ do_classifier_free_guidance,
348
+ text_model_output,
349
+ text_attention_mask,
350
+ )
351
+
352
+ prior_latents = prior_latents.expand(
353
+ (
354
+ prompt_embeds.shape[0] // 2
355
+ if do_classifier_free_guidance
356
+ else prompt_embeds.shape[0],
357
+ prompt_embeds.shape[1],
358
+ )
359
+ )
360
+ image_embeddings = prior_latents.clone()
361
+ # return image_embeddings
362
+
363
+ # decoder
364
+ text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
365
+ image_embeddings=image_embeddings,
366
+ prompt_embeds=prompt_embeds,
367
+ text_encoder_hidden_states=text_encoder_hidden_states,
368
+ do_classifier_free_guidance=do_classifier_free_guidance,
369
+ )
370
+
371
+ if device.type == "mps":
372
+ # HACK: MPS: There is a panic when padding bool tensors,
373
+ # so cast to int tensor for the pad and back to bool afterwards
374
+ text_mask = text_mask.type(torch.int)
375
+ decoder_text_mask = F.pad(
376
+ text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1
377
+ )
378
+ decoder_text_mask = decoder_text_mask.type(torch.bool)
379
+ else:
380
+ decoder_text_mask = F.pad(
381
+ text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=True
382
+ )
383
+
384
+ self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
385
+ decoder_timesteps_tensor = self.decoder_scheduler.timesteps
386
+
387
+ num_channels_latents = self.decoder.config.in_channels
388
+ height = self.decoder.config.sample_size
389
+ width = self.decoder.config.sample_size
390
+
391
+ decoder_latents = self.prepare_latents(
392
+ (batch_size, num_channels_latents, height, width),
393
+ text_encoder_hidden_states.dtype,
394
+ device,
395
+ generator,
396
+ decoder_latents,
397
+ self.decoder_scheduler,
398
+ )
399
+
400
+ for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
401
+ # expand the latents if we are doing classifier free guidance
402
+ latent_model_input = (
403
+ torch.cat([decoder_latents] * 2)
404
+ if do_classifier_free_guidance
405
+ else decoder_latents
406
+ )
407
+
408
+ noise_pred = self.decoder(
409
+ sample=latent_model_input,
410
+ timestep=t,
411
+ encoder_hidden_states=text_encoder_hidden_states,
412
+ class_labels=additive_clip_time_embeddings,
413
+ attention_mask=decoder_text_mask,
414
+ ).sample
415
+
416
+ if do_classifier_free_guidance:
417
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
418
+ noise_pred_uncond, _ = noise_pred_uncond.split(
419
+ latent_model_input.shape[1], dim=1
420
+ )
421
+ noise_pred_text, predicted_variance = noise_pred_text.split(
422
+ latent_model_input.shape[1], dim=1
423
+ )
424
+ noise_pred = noise_pred_uncond + decoder_guidance_scale * (
425
+ noise_pred_text - noise_pred_uncond
426
+ )
427
+ noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
428
+
429
+ if i + 1 == decoder_timesteps_tensor.shape[0]:
430
+ prev_timestep = None
431
+ else:
432
+ prev_timestep = decoder_timesteps_tensor[i + 1]
433
+
434
+ # compute the previous noisy sample x_t -> x_t-1
435
+ decoder_latents = self.decoder_scheduler.step(
436
+ noise_pred,
437
+ t,
438
+ decoder_latents,
439
+ prev_timestep=prev_timestep,
440
+ generator=generator,
441
+ ).prev_sample
442
+
443
+ decoder_latents = decoder_latents.clamp(-1, 1)
444
+
445
+ image_small = decoder_latents
446
+
447
+ # done decoder
448
+
449
+ # super res
450
+
451
+ self.super_res_scheduler.set_timesteps(
452
+ super_res_num_inference_steps, device=device
453
+ )
454
+ super_res_timesteps_tensor = self.super_res_scheduler.timesteps
455
+
456
+ channels = self.super_res_first.config.in_channels // 2
457
+ height = self.super_res_first.config.sample_size
458
+ width = self.super_res_first.config.sample_size
459
+
460
+ super_res_latents = self.prepare_latents(
461
+ (batch_size, channels, height, width),
462
+ image_small.dtype,
463
+ device,
464
+ generator,
465
+ super_res_latents,
466
+ self.super_res_scheduler,
467
+ )
468
+
469
+ if device.type == "mps":
470
+ # MPS does not support many interpolations
471
+ image_upscaled = F.interpolate(image_small, size=[height, width])
472
+ else:
473
+ interpolate_antialias = {}
474
+ if "antialias" in inspect.signature(F.interpolate).parameters:
475
+ interpolate_antialias["antialias"] = True
476
+
477
+ image_upscaled = F.interpolate(
478
+ image_small,
479
+ size=[height, width],
480
+ mode="bicubic",
481
+ align_corners=False,
482
+ **interpolate_antialias,
483
+ )
484
+
485
+ for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
486
+ # no classifier free guidance
487
+
488
+ if i == super_res_timesteps_tensor.shape[0] - 1:
489
+ unet = self.super_res_last
490
+ else:
491
+ unet = self.super_res_first
492
+
493
+ latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1)
494
+
495
+ noise_pred = unet(
496
+ sample=latent_model_input,
497
+ timestep=t,
498
+ ).sample
499
+
500
+ if i + 1 == super_res_timesteps_tensor.shape[0]:
501
+ prev_timestep = None
502
+ else:
503
+ prev_timestep = super_res_timesteps_tensor[i + 1]
504
+
505
+ # compute the previous noisy sample x_t -> x_t-1
506
+ super_res_latents = self.super_res_scheduler.step(
507
+ noise_pred,
508
+ t,
509
+ super_res_latents,
510
+ prev_timestep=prev_timestep,
511
+ generator=generator,
512
+ ).prev_sample
513
+
514
+ image = super_res_latents
515
+ # done super res
516
+
517
+ # post processing
518
+
519
+ image = image * 0.5 + 0.5
520
+ image = image.clamp(0, 1)
521
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
522
+
523
+ if output_type == "pil":
524
+ image = self.numpy_to_pil(image)
525
+
526
+ if not return_dict:
527
+ return (image,)
528
+
529
+ return ImagePipelineOutput(images=image)
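
This unCLIP variant is not wired into `main.py`; it mirrors diffusers' `UnCLIPPipeline` but, like the Kandinsky prior pipeline above, runs the (time-free) prior in a single forward pass before the usual decoder and super-resolution stages. A hedged sketch of how it could be used, assuming a Karlo-style unCLIP checkpoint such as `kakaobrain/karlo-v1-alpha` and an ECLIPSE prior checkpoint trained for it (not part of this commit, so the repo id below is a placeholder):

```py
# Illustration only; the prior repo id is a placeholder, not a real checkpoint name.
import torch

from src.priors.prior_transformer import PriorTransformer
from src.pipelines.pipeline_unclip import UnCLIPPipeline

prior = PriorTransformer.from_pretrained(
    "<eclipse-prior-for-karlo>",  # placeholder: substitute the actual ECLIPSE prior
    torch_dtype=torch.float16,
)
pipe = UnCLIPPipeline.from_pretrained(
    "kakaobrain/karlo-v1-alpha", prior=prior, torch_dtype=torch.float16
).to("cuda")

# Single prior forward pass, then decoder + super resolution as in the stock pipeline.
image = pipe("red cat, 4k photo", decoder_num_inference_steps=25).images[0]
image.save("eclipse_unclip.png")
```
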
src/priors/__init__.py ADDED
File without changes
src/priors/prior_transformer.py ADDED
@@ -0,0 +1,368 @@
1
+ import sys
+ sys.path.append("..")
+
+ from dataclasses import dataclass
+ from typing import Dict, Optional, Union
+
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.utils import BaseOutput
+ from diffusers.models.attention import BasicTransformerBlock
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
+ from diffusers.models.modeling_utils import ModelMixin
+
+
+ @dataclass
+ class PriorTransformerOutput(BaseOutput):
+     """
+     The output of [`PriorTransformer`].
+
+     Args:
+         predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
+             The predicted CLIP image embedding conditioned on the CLIP text embedding input.
+     """
+
+     predicted_image_embedding: torch.FloatTensor
+
+
+ class PriorTransformer(ModelMixin, ConfigMixin):
+     """
+     A Prior Transformer model.
+
+     Parameters:
+         num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
+         attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
+         num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
+         embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`.
+         num_embeddings (`int`, *optional*, defaults to 77):
+             The number of embeddings of the model input `hidden_states`.
+         additional_embeddings (`int`, *optional*, defaults to 3): The number of additional tokens appended to the
+             projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
+             additional_embeddings`.
+         dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+         norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
+             passing to Transformer blocks. Set it to `None` if normalization is not needed.
+         embedding_proj_norm_type (`str`, *optional*, defaults to None):
+             The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
+             needed.
+         encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
+             The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
+             `encoder_hidden_states` is `None`.
+         added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
+             Choose from `prd` or `None`. If `prd`, a token encoding the (quantized) dot product between the text
+             embedding and the image embedding, as proposed in the unCLIP paper
+             https://arxiv.org/abs/2204.06125, is appended. If `None`, no additional embeddings are used.
+         embedding_proj_dim (`int`, *optional*, defaults to None):
+             The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
+         clip_embed_dim (`int`, *optional*, defaults to None):
+             The dimension of the output. If None, will be set to `embedding_dim`.
+     """
+
+     @register_to_config
+     def __init__(
+         self,
+         num_attention_heads: int = 32,
+         attention_head_dim: int = 64,
+         num_layers: int = 20,
+         embedding_dim: int = 768,
+         num_embeddings=77,
+         additional_embeddings=3,  # as we have removed the time embedding
+         dropout: float = 0.0,
+         # time_embed_act_fn: str = "silu",
+         norm_in_type: Optional[str] = None,  # layer
+         embedding_proj_norm_type: Optional[str] = None,  # layer
+         encoder_hid_proj_type: Optional[str] = "linear",  # linear
+         added_emb_type: Optional[str] = "prd",  # prd
+         # time_embed_dim: Optional[int] = None,
+         embedding_proj_dim: Optional[int] = None,
+         clip_embed_dim: Optional[int] = None,
+     ):
+         super().__init__()
+         self.num_attention_heads = num_attention_heads
+         self.attention_head_dim = attention_head_dim
+         inner_dim = num_attention_heads * attention_head_dim
+         self.additional_embeddings = additional_embeddings
+
+         # time_embed_dim = time_embed_dim or inner_dim
+         embedding_proj_dim = embedding_proj_dim or embedding_dim
+         clip_embed_dim = clip_embed_dim or embedding_dim
+
+         # self.time_proj = Timesteps(inner_dim, True, 0)
+         # self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)
+
+         self.proj_in = nn.Linear(embedding_dim, inner_dim)
+
+         if embedding_proj_norm_type is None:
+             self.embedding_proj_norm = None
+         elif embedding_proj_norm_type == "layer":
+             self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
+         else:
+             raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")
+
+         self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)
+
+         if encoder_hid_proj_type is None:
+             self.encoder_hidden_states_proj = None
+         elif encoder_hid_proj_type == "linear":
+             self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
+         else:
+             raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")
+
+         self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
+
+         if added_emb_type == "prd":
+             self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
+         elif added_emb_type is None:
+             self.prd_embedding = None
+         else:
+             raise ValueError(
+                 f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
+             )
+
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 BasicTransformerBlock(
+                     inner_dim,
+                     num_attention_heads,
+                     attention_head_dim,
+                     dropout=dropout,
+                     activation_fn="gelu",
+                     attention_bias=True,
+                 )
+                 for d in range(num_layers)
+             ]
+         )
+
+         if norm_in_type == "layer":
+             self.norm_in = nn.LayerNorm(inner_dim)
+         elif norm_in_type is None:
+             self.norm_in = None
+         else:
+             raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")
+
+         self.norm_out = nn.LayerNorm(inner_dim)
+
+         self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)
+
+         causal_attention_mask = torch.full(
+             [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
+         )
+         causal_attention_mask.triu_(1)
+         causal_attention_mask = causal_attention_mask[None, ...]
+         self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
+
+         self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
+         self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
+
+     @property
+     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+     def attn_processors(self) -> Dict[str, AttentionProcessor]:
+         r"""
+         Returns:
+             `dict` of attention processors: A dictionary containing all attention processors used in the model,
+             indexed by their weight names.
+         """
+         # set recursively
+         processors = {}
+
+         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+             if hasattr(module, "set_processor"):
+                 processors[f"{name}.processor"] = module.processor
+
+             for sub_name, child in module.named_children():
+                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+             return processors
+
+         for name, module in self.named_children():
+             fn_recursive_add_processors(name, module, processors)
+
+         return processors
+
+     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+     def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+         r"""
+         Sets the attention processor to use to compute attention.
+
+         Parameters:
+             processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                 The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                 for **all** `Attention` layers.
+
+                 If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                 processor. This is strongly recommended when setting trainable attention processors.
+
+         """
+         count = len(self.attn_processors.keys())
+
+         if isinstance(processor, dict) and len(processor) != count:
+             raise ValueError(
+                 f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                 f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+             )
+
+         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+             if hasattr(module, "set_processor"):
+                 if not isinstance(processor, dict):
+                     module.set_processor(processor)
+                 else:
+                     module.set_processor(processor.pop(f"{name}.processor"))
+
+             for sub_name, child in module.named_children():
+                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+         for name, module in self.named_children():
+             fn_recursive_attn_processor(name, module, processor)
+
+     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+     def set_default_attn_processor(self):
+         """
+         Disables custom attention processors and sets the default attention implementation.
+         """
+         self.set_attn_processor(AttnProcessor())
+
+     def forward(
+         self,
+         hidden_states,
+         # timestep: Union[torch.Tensor, float, int],
+         proj_embedding: torch.FloatTensor,
+         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+         attention_mask: Optional[torch.BoolTensor] = None,
+         return_dict: bool = True,
+     ):
+         """
+         The [`PriorTransformer`] forward method.
+
+         Args:
+             hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
+                 The currently predicted image embeddings.
+             proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
+                 Projected embedding vector the denoising process is conditioned on.
+             encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
+                 Hidden states of the text embeddings the denoising process is conditioned on.
+             attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
+                 Text mask for the text embeddings.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
+                 tuple.
+
+         Returns:
+             [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
+                 If `return_dict` is `True`, a [`~models.prior_transformer.PriorTransformerOutput`] is returned,
+                 otherwise a tuple is returned where the first element is the sample tensor.
+         """
+         batch_size = hidden_states.shape[0]
+
+         # timesteps = timestep
+         # if not torch.is_tensor(timesteps):
+         #     timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
+         # elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+         #     timesteps = timesteps[None].to(hidden_states.device)
+
+         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+         # timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)
+
+         # timesteps_projected = self.time_proj(timesteps)
+
+         # timesteps does not contain any weights and will always return f32 tensors
+         # but time_embedding might be fp16, so we need to cast here.
+         # timesteps_projected = timesteps_projected.to(dtype=self.dtype)
+         # time_embeddings = self.time_embedding(timesteps_projected)
+
+         if self.embedding_proj_norm is not None:
+             proj_embedding = self.embedding_proj_norm(proj_embedding)
+
+         proj_embeddings = self.embedding_proj(proj_embedding)
+         if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
+             encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
+         elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
+             raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")
+
+         hidden_states = self.proj_in(hidden_states)
+
+         positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
+
+         additional_embeds = []
+         additional_embeddings_len = 0
+
+         if encoder_hidden_states is not None:
+             additional_embeds.append(encoder_hidden_states)
+             additional_embeddings_len += encoder_hidden_states.shape[1]
+
+         if len(proj_embeddings.shape) == 2:
+             proj_embeddings = proj_embeddings[:, None, :]
+
+         if len(hidden_states.shape) == 2:
+             hidden_states = hidden_states[:, None, :]
+
+         additional_embeds = additional_embeds + [
+             proj_embeddings,
+             # time_embeddings[:, None, :],
+             hidden_states,
+         ]
+
+         if self.prd_embedding is not None:
+             prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
+             additional_embeds.append(prd_embedding)
+
+         hidden_states = torch.cat(
+             additional_embeds,
+             dim=1,
+         )
+
+         # Allow positional_embedding to not include the `additional_embeddings` and instead pad it with zeros for these additional tokens
+         additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
+         if positional_embeddings.shape[1] < hidden_states.shape[1]:
+             positional_embeddings = F.pad(
+                 positional_embeddings,
+                 (
+                     0,
+                     0,
+                     additional_embeddings_len,
+                     self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
+                 ),
+                 value=0.0,
+             )
+
+         hidden_states = hidden_states + positional_embeddings
+
+         if attention_mask is not None:
+             attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+             attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
+             attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
+             attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
+
+         if self.norm_in is not None:
+             hidden_states = self.norm_in(hidden_states)
+
+         for block in self.transformer_blocks:
+             hidden_states = block(hidden_states, attention_mask=attention_mask)
+
+         hidden_states = self.norm_out(hidden_states)
+
+         if self.prd_embedding is not None:
+             hidden_states = hidden_states[:, -1]
+         else:
+             hidden_states = hidden_states[:, additional_embeddings_len:]
+
+         predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
+
+         if not return_dict:
+             return (predicted_image_embedding,)
+
+         return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
+
+     def post_process_latents(self, prior_latents):
+         prior_latents = (prior_latents * self.clip_std) + self.clip_mean
+         return prior_latents
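Note: the module above is a diffusers-style prior transformer with the timestep embedding removed, so `forward` is conditioned only on `proj_embedding` and `encoder_hidden_states`. A minimal, hypothetical smoke test of the class as defined above (random placeholder inputs, default config; the untrained `clip_mean`/`clip_std` are zeros, so this only checks shapes, not outputs):

import torch

prior = PriorTransformer()  # defaults: 32 heads x 64 dims, 20 layers, embedding_dim=768
prior.set_default_attn_processor()  # use the stock attention processors

batch_size = 2
hidden_states = torch.randn(batch_size, 768)               # current image-embedding estimate
proj_embedding = torch.randn(batch_size, 768)               # e.g. a pooled CLIP text embedding
encoder_hidden_states = torch.randn(batch_size, 77, 768)    # per-token text hidden states
attention_mask = torch.ones(batch_size, 77, dtype=torch.bool)

with torch.no_grad():
    out = prior(
        hidden_states,
        proj_embedding=proj_embedding,
        encoder_hidden_states=encoder_hidden_states,
        attention_mask=attention_mask,
    )

# De-normalize with the learned CLIP statistics (all zeros here, since the model is untrained).
image_embed = prior.post_process_latents(out.predicted_image_embedding)
print(image_embed.shape)  # torch.Size([2, 768])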