Tony Lian commited on
Commit
1f39cf9
·
1 Parent(s): 58524a7

Add stage 2

Browse files
.gitignore ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
README.md CHANGED
@@ -4,10 +4,12 @@ emoji: 😊
4
  colorFrom: red
5
  colorTo: pink
6
  sdk: gradio
7
- sdk_version: 3.32.0
8
  app_file: app.py
9
  pinned: true
10
  tags: [llm, diffusion, grounding, grounded, llm-grounded, text-to-image, language, large language models, layout, generation, generative, customization, personalization, prompting, chatgpt, gpt-3.5, gpt-4]
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
4
  colorFrom: red
5
  colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 3.34.0
8
  app_file: app.py
9
  pinned: true
10
  tags: [llm, diffusion, grounding, grounded, llm-grounded, text-to-image, language, large language models, layout, generation, generative, customization, personalization, prompting, chatgpt, gpt-3.5, gpt-4]
11
  ---
12
 
13
+ Credits:
14
+
15
+ This space uses code from [diffusers](https://huggingface.co/docs/diffusers/index), [GLIGEN](https://github.com/gligen/GLIGEN), and [layout-guidance](https://github.com/silent-chen/layout-guidance). Using their code means adhering to their license.
app.py CHANGED
@@ -4,13 +4,21 @@ import ast
4
  from matplotlib.patches import Polygon
5
  from matplotlib.collections import PatchCollection
6
  import matplotlib.pyplot as plt
 
 
 
 
 
 
 
 
7
 
8
  box_scale = (512, 512)
9
  size = box_scale
10
 
11
  bg_prompt_text = "Background prompt: "
12
 
13
- simplified_prompt = """You are an intelligent bounding box generator. I will provide you with a caption for a photo, image, or painting. Your task is to generate the bounding boxes for the objects mentioned in the caption, along with a background prompt describing the scene. The images are of size 512x512, and the bounding boxes should not overlap or go beyond the image boundaries. Each bounding box should be in the format of (object name, [top-left x coordinate, top-left y coordinate, box width, box height]) and include exactly one object. Do not put objects that are already provided in the bounding boxes into the background prompt. If needed, you can make reasonable guesses. Please refer to the example below for the desired format.
14
 
15
  Caption: A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky
16
  Objects: [('a green car', [21, 181, 211, 159]), ('a blue truck', [269, 181, 209, 160]), ('a red air balloon', [66, 8, 145, 135]), ('a bird', [296, 42, 143, 100])]
@@ -43,12 +51,20 @@ Background prompt: An oil painting of a living room scene
43
  Caption: {prompt}
44
  Objects: """
45
 
 
 
 
 
 
 
46
  def get_lmd_prompt(prompt):
47
  if prompt == "":
48
- prompt = "A realistic photo of a gray cat and an orange dog on the grass."
49
  return simplified_prompt.format(prompt=prompt)
50
 
51
  def get_layout_image(response):
 
 
52
  gen_boxes, bg_prompt = parse_input(response)
53
  fig = plt.figure(figsize=(8, 8))
54
  # https://stackoverflow.com/questions/7821518/save-plot-to-numpy-array
@@ -63,6 +79,35 @@ def get_layout_image(response):
63
  plt.clf()
64
  return data
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  def parse_input(text=None):
67
  try:
68
  if "Objects: " in text:
@@ -130,30 +175,73 @@ def show_boxes(gen_boxes, bg_prompt=None):
130
 
131
  draw_boxes(anns)
132
 
133
- with gr.Blocks() as g:
134
- gr.HTML("""<h1>LLM-grounded Diffusion: Enhancing Prompt Understanding of Text-to-Image Diffusion Models with Large Language Models</h1>
135
- <p>This is a space that allows you to explore the layouts generated by ChatGPT on your own with a simplified set of examples. The layout-to-image generation part will be added.</p>
136
- <p>Read our <a href='https://llm-grounded-diffusion.github.io/'>a brief introduction on our project page</a> or <a href='https://arxiv.org/pdf/2305.13655.pdf'>our work on arxiv</a>. <a href='https://llm-grounded-diffusion.github.io/#citation'>Cite our work</a> if our ideas inspire you.</p>
 
 
 
 
137
  <p><b>Tips:</b><p>
138
  <p>1. If ChatGPT doesn't generate layout, add/remove the trailing space (added by default) and/or use GPT-4.</p>
139
  <p>2. You can perform multi-round specification by giving ChatGPT follow-up requests (e.g., make the object boxes bigger).</p>
140
- <p>3. You can also try prompts in Simplified Chinese. If you want to try prompts in another language, translate the first line of last example to your language.<p>""")
141
- with gr.Tab("Image Prompt to ChatGPT"):
 
 
 
142
  with gr.Row():
143
  with gr.Column(scale=1):
144
- prompt = gr.Textbox(lines=2, label="Prompt for Layout Generation", placeholder="A realistic photo of a gray cat and an orange dog on the grass.")
145
- greet_btn = gr.Button("Generate Prompt")
146
  with gr.Column(scale=1):
147
- output = gr.Textbox(label="Paste this into ChatGPT (GPT-4 usually gives better results)")
148
- greet_btn.click(fn=get_lmd_prompt, inputs=prompt, outputs=output, api_name="get_lmd_prompt")
 
 
 
 
 
 
 
 
 
149
 
150
- with gr.Tab("Visualize ChatGPT-generated Layout"):
151
  with gr.Row():
152
  with gr.Column(scale=1):
153
- prompt = gr.Textbox(lines=2, label="Paste ChatGPT response here", placeholder="Paste ChatGPT response here")
154
- greet_btn = gr.Button("Visualize Layout")
 
 
 
 
 
 
 
 
155
  with gr.Column(scale=1):
156
- output = gr.Image(shape=(512, 512), elem_classes="img", elem_id="img", css="img {width: 300px}")
157
- greet_btn.click(fn=get_layout_image, inputs=prompt, outputs=output, api_name="chatgpt-to-layout")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
  g.launch()
 
4
  from matplotlib.patches import Polygon
5
  from matplotlib.collections import PatchCollection
6
  import matplotlib.pyplot as plt
7
+ from utils.parse import filter_boxes
8
+ from generation import run as run_ours
9
+ from baseline import run as run_baseline
10
+ import torch
11
+
12
+ print(f"Is CUDA available: {torch.cuda.is_available()}")
13
+ if torch.cuda.is_available():
14
+ print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
15
 
16
  box_scale = (512, 512)
17
  size = box_scale
18
 
19
  bg_prompt_text = "Background prompt: "
20
 
21
+ simplified_prompt = """You are an intelligent bounding box generator. I will provide you with a caption for a photo, image, or painting. Your task is to generate the bounding boxes for the objects mentioned in the caption, along with a background prompt describing the scene. The images are of size 512x512, and the bounding boxes should not overlap or go beyond the image boundaries. Each bounding box should be in the format of (object name, [top-left x coordinate, top-left y coordinate, box width, box height]) and include exactly one object. Do not put objects that are already provided in the bounding boxes into the background prompt. If needed, you can make reasonable guesses. Generate the object descriptions and background prompts in English even if the caption might not be in English. Please refer to the example below for the desired format.
22
 
23
  Caption: A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky
24
  Objects: [('a green car', [21, 181, 211, 159]), ('a blue truck', [269, 181, 209, 160]), ('a red air balloon', [66, 8, 145, 135]), ('a bird', [296, 42, 143, 100])]
 
51
  Caption: {prompt}
52
  Objects: """
53
 
54
+ prompt_placeholder = "A realistic photo of a gray cat and an orange dog on the grass."
55
+
56
+ layout_placeholder = """Caption: A realistic photo of a gray cat and an orange dog on the grass.
57
+ Objects: [('a gray cat', [67, 243, 120, 126]), ('an orange dog', [265, 193, 190, 210])]
58
+ Background prompt: A realistic photo of a grassy area."""
59
+
60
  def get_lmd_prompt(prompt):
61
  if prompt == "":
62
+ prompt = prompt_placeholder
63
  return simplified_prompt.format(prompt=prompt)
64
 
65
  def get_layout_image(response):
66
+ if response == "":
67
+ response = layout_placeholder
68
  gen_boxes, bg_prompt = parse_input(response)
69
  fig = plt.figure(figsize=(8, 8))
70
  # https://stackoverflow.com/questions/7821518/save-plot-to-numpy-array
 
79
  plt.clf()
80
  return data
81
 
82
+ def get_layout_image_gallery(response):
83
+ return [get_layout_image(response)]
84
+
85
+ def get_ours_image(response, seed, fg_seed_start, fg_blending_ratio=0.1, frozen_step_ratio=0.4, gligen_scheduled_sampling_beta=0.3, show_so_imgs=False, scale_boxes=False, gallery=None):
86
+ if response == "":
87
+ response = layout_placeholder
88
+ gen_boxes, bg_prompt = parse_input(response)
89
+ gen_boxes = filter_boxes(gen_boxes, scale_boxes=scale_boxes)
90
+ spec = {
91
+ # prompt is unused
92
+ 'prompt': '',
93
+ 'gen_boxes': gen_boxes,
94
+ 'bg_prompt': bg_prompt
95
+ }
96
+ image_np, so_img_list = run_ours(
97
+ spec, bg_seed=seed, fg_seed_start=fg_seed_start,
98
+ fg_blending_ratio=fg_blending_ratio,frozen_step_ratio=frozen_step_ratio,
99
+ gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta)
100
+ images = [image_np]
101
+ if show_so_imgs:
102
+ images.extend([np.asarray(so_img) for so_img in so_img_list])
103
+ return images
104
+
105
+ def get_baseline_image(prompt, seed):
106
+ if prompt == "":
107
+ prompt = prompt_placeholder
108
+ image_np = run_baseline(prompt, bg_seed=seed)
109
+ return [image_np]
110
+
111
  def parse_input(text=None):
112
  try:
113
  if "Objects: " in text:
 
175
 
176
  draw_boxes(anns)
177
 
178
+ duplicate_html = '<a style="display:inline-block" href="https://huggingface.co/spaces/longlian/llm-grounded-diffusion?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a>'
179
+
180
+ with gr.Blocks(
181
+ title="LLM-grounded Diffusion: Enhancing Prompt Understanding of Text-to-Image Diffusion Models with Large Language Models"
182
+ ) as g:
183
+ gr.HTML(f"""<h1>LLM-grounded Diffusion: Enhancing Prompt Understanding of Text-to-Image Diffusion Models with Large Language Models</h1>
184
+ <h2>LLM + Stable Diffusion => better prompt understanding in text2image generation 🤩</h2>
185
+ <h2><a href='https://llm-grounded-diffusion.github.io/'>Project Page</a> | <a href='https://bair.berkeley.edu/blog/2023/05/23/lmd/'>5-minute Blog Post</a> | <a href='https://arxiv.org/pdf/2305.13655.pdf'>ArXiv Paper</a> | <a href='https://github.com/TonyLianLong/LLM-groundedDiffusion'>Github</a> | <a href='https://llm-grounded-diffusion.github.io/#citation'>Cite our work</a> if our ideas inspire you.</h2>
186
  <p><b>Tips:</b><p>
187
  <p>1. If ChatGPT doesn't generate layout, add/remove the trailing space (added by default) and/or use GPT-4.</p>
188
  <p>2. You can perform multi-round specification by giving ChatGPT follow-up requests (e.g., make the object boxes bigger).</p>
189
+ <p>3. You can also try prompts in Simplified Chinese. If you want to try prompts in another language, translate the first line of last example to your language.<p>
190
+ <p>4. Duplicate this space and add GPU to skip the queue and run our model faster. {duplicate_html}</p>
191
+ <br/>
192
+ <p>Implementation note: In this demo, we replace the attention manipulation in our layout-guided Stable Diffusion described in our paper with GLIGEN due to much faster inference speed (<b>FlashAttention supported, no backprop needed</b> during inference). Compared to vanilla GLIGEN, we have better coherence. Other parts of text-to-image pipeline, including single object generation and SAM, remain the same. The settings and examples in the prompt are simplified in this demo.</p>""")
193
+ with gr.Tab("Stage 1. Image Prompt to ChatGPT"):
194
  with gr.Row():
195
  with gr.Column(scale=1):
196
+ prompt = gr.Textbox(lines=2, label="Prompt for Layout Generation", placeholder=prompt_placeholder)
197
+ generate_btn = gr.Button("Generate Prompt")
198
  with gr.Column(scale=1):
199
+ output = gr.Textbox(label="Paste this into ChatGPT (GPT-4 preferred; on Mac, click text and press Command+A and Command+C to copy all)")
200
+ generate_btn.click(fn=get_lmd_prompt, inputs=prompt, outputs=output, api_name="get_lmd_prompt")
201
+
202
+ # with gr.Tab("(Optional) Visualize ChatGPT-generated Layout"):
203
+ # with gr.Row():
204
+ # with gr.Column(scale=1):
205
+ # response = gr.Textbox(lines=5, label="Paste ChatGPT response here", placeholder=layout_placeholder)
206
+ # visualize_btn = gr.Button("Visualize Layout")
207
+ # with gr.Column(scale=1):
208
+ # output = gr.Image(shape=(512, 512), elem_classes="img", elem_id="img", css="img {width: 300px}")
209
+ # visualize_btn.click(fn=get_layout_image, inputs=response, outputs=output, api_name="visualize-layout")
210
 
211
+ with gr.Tab("Stage 2 (New). Layout to Image generation"):
212
  with gr.Row():
213
  with gr.Column(scale=1):
214
+ response = gr.Textbox(lines=5, label="Paste ChatGPT response here (no original caption needed)", placeholder=layout_placeholder)
215
+ visualize_btn = gr.Button("Visualize Layout")
216
+ generate_btn = gr.Button("Generate Image from Layout", variant='primary')
217
+ with gr.Accordion("Advanced options", open=False):
218
+ seed = gr.Slider(0, 10000, value=0, step=1, label="Seed")
219
+ fg_seed_start = gr.Slider(0, 10000, value=20, step=1, label="Seed for foreground variation")
220
+ fg_blending_ratio = gr.Slider(0, 1, value=0.1, step=0.01, label="Variations added to foreground for single object generation (0: no variation, 1: max variation)")
221
+ frozen_step_ratio = gr.Slider(0, 1, value=0.4, step=0.1, label="Foreground frozen steps ratio (higher: preserve object attributes; lower: higher coherence; set to 0: (almost) equivalent to vanilla GLIGEN except details)")
222
+ gligen_scheduled_sampling_beta = gr.Slider(0, 1, value=0.3, step=0.1, label="GLIGEN guidance steps ratio (the beta value)")
223
+ show_so_imgs = gr.Checkbox(label="Show annotated single object generations", show_label=False)
224
  with gr.Column(scale=1):
225
+ gallery = gr.Gallery(
226
+ label="Generated image", show_label=False, elem_id="gallery"
227
+ ).style(columns=[1], rows=[1], object_fit="contain", preview=True)
228
+ visualize_btn.click(fn=get_layout_image_gallery, inputs=response, outputs=gallery, api_name="visualize-layout")
229
+ generate_btn.click(fn=get_ours_image, inputs=[response, seed, fg_seed_start, fg_blending_ratio, frozen_step_ratio, gligen_scheduled_sampling_beta, show_so_imgs], outputs=gallery, api_name="layout-to-image")
230
+
231
+ with gr.Tab("Baseline: Stable Diffusion"):
232
+ with gr.Row():
233
+ with gr.Column(scale=1):
234
+ sd_prompt = gr.Textbox(lines=2, label="Prompt for baseline SD", placeholder=prompt_placeholder)
235
+ generate_btn = gr.Button("Generate")
236
+ with gr.Accordion("Advanced options", open=False):
237
+ seed = gr.Slider(0, 10000, value=0, step=1, label="Seed")
238
+ # with gr.Column(scale=1):
239
+ # output = gr.Image(shape=(512, 512), elem_classes="img", elem_id="img")
240
+ with gr.Column(scale=1):
241
+ gallery = gr.Gallery(
242
+ label="Generated image", show_label=False, elem_id="gallery2"
243
+ ).style(columns=[1], rows=[1], object_fit="contain", preview=True)
244
+ generate_btn.click(fn=get_baseline_image, inputs=[sd_prompt, seed], outputs=gallery, api_name="baseline")
245
+
246
 
247
  g.launch()
baseline.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Original Stable Diffusion (1.4)
2
+
3
+ import torch
4
+ import models
5
+ from models import pipelines
6
+ from shared import model_dict
7
+
8
+ vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.dtype
9
+
10
+ torch.set_grad_enabled(False)
11
+
12
+ height = 512 # default height of Stable Diffusion
13
+ width = 512 # default width of Stable Diffusion
14
+ num_inference_steps = 20 # Number of denoising steps
15
+ guidance_scale = 7.5 # Scale for classifier-free guidance
16
+ batch_size = 1
17
+
18
+ # h, w
19
+ image_scale = (512, 512)
20
+
21
+ bg_negative = 'artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image, bad proportions, duplicate'
22
+
23
+ def run(prompt, bg_seed=1):
24
+ print(f"prompt: {prompt}")
25
+ generator = torch.Generator(models.torch_device).manual_seed(bg_seed)
26
+
27
+ prompts = [prompt]
28
+ input_embeddings = models.encode_prompts(prompts=prompts, tokenizer=tokenizer, text_encoder=text_encoder, negative_prompt=bg_negative)
29
+
30
+ generator = torch.manual_seed(1) # Seed generator to create the inital latent noise
31
+ latents = models.get_unscaled_latents(batch_size, unet.config.in_channels, height, width, generator, dtype)
32
+
33
+ latents = latents * scheduler.init_noise_sigma
34
+
35
+ pipelines.gligen_enable_fuser(model_dict['unet'], enabled=False)
36
+ _, images = pipelines.generate(
37
+ model_dict, latents, input_embeddings, num_inference_steps,
38
+ guidance_scale=guidance_scale
39
+ )
40
+
41
+ return images[0]
generation.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version = "v3.0"
2
+
3
+ from PIL import Image
4
+ import torch
5
+ import models
6
+ from models import load_sd
7
+ import utils
8
+ from models import pipelines, sam
9
+ from utils import parse, latents
10
+ from shared import model_dict, sam_model_dict
11
+
12
+ verbose = False
13
+
14
+ vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.dtype
15
+
16
+ model_dict.update(sam_model_dict)
17
+
18
+
19
+ # Hyperparams
20
+ height = 512 # default height of Stable Diffusion
21
+ width = 512 # default width of Stable Diffusion
22
+ H, W = height // 8, width // 8 # size of the latent
23
+ num_inference_steps = 20 # Number of denoising steps
24
+ guidance_scale = 7.5 # Scale for classifier-free guidance
25
+
26
+ # batch size that is not 1 is not supported
27
+ so_batch_size = 1
28
+ overall_batch_size = 1
29
+
30
+ # discourage masks with confidence below
31
+ discourage_mask_below_confidence = 0.85
32
+
33
+ # discourage masks with iou (with coarse binarized attention mask) below
34
+ discourage_mask_below_coarse_iou = 0.25
35
+
36
+ run_ind = None
37
+
38
+
39
+ def generate_single_object_with_box(prompt, box, phrase, word, input_latents, input_embeddings,
40
+ sam_refine_kwargs, gligen_scheduled_sampling_beta=0.3,
41
+ verbose=False, visualize=True):
42
+
43
+ bboxes, phrases, words = [box], [phrase], [word]
44
+
45
+ latents, single_object_images, single_object_pil_images_box_ann, latents_all = pipelines.generate_gligen(
46
+ model_dict, input_latents, input_embeddings, num_inference_steps, bboxes, phrases, gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta,
47
+ guidance_scale=guidance_scale, return_saved_cross_attn=False,
48
+ return_box_vis=True, save_all_latents=True
49
+ )
50
+
51
+ mask_selected, conf_score_selected = sam.sam_refine_box(sam_input_image=single_object_images[0], box=box, model_dict=model_dict, verbose=verbose, **sam_refine_kwargs)
52
+
53
+ mask_selected_tensor = torch.tensor(mask_selected)
54
+
55
+ return latents_all, mask_selected_tensor, single_object_pil_images_box_ann[0]
56
+
57
+ def get_masked_latents_all_list(so_prompt_phrase_word_box_list, input_latents_list, so_input_embeddings, verbose=False, **kwargs):
58
+ latents_all_list, mask_tensor_list, so_img_list = [], [], []
59
+
60
+ if not so_prompt_phrase_word_box_list:
61
+ return latents_all_list, mask_tensor_list
62
+
63
+ so_uncond_embeddings, so_cond_embeddings = so_input_embeddings
64
+
65
+ for idx, ((prompt, phrase, word, box), input_latents) in enumerate(zip(so_prompt_phrase_word_box_list, input_latents_list)):
66
+ so_current_cond_embeddings = so_cond_embeddings[idx:idx+1]
67
+ so_current_text_embeddings = torch.cat([so_uncond_embeddings, so_current_cond_embeddings], dim=0)
68
+ so_current_input_embeddings = so_current_text_embeddings, so_uncond_embeddings, so_current_cond_embeddings
69
+
70
+ latents_all, mask_tensor, so_img = generate_single_object_with_box(prompt, box, phrase, word, input_latents, input_embeddings=so_current_input_embeddings, verbose=verbose, **kwargs)
71
+ latents_all_list.append(latents_all)
72
+ mask_tensor_list.append(mask_tensor)
73
+ so_img_list.append(so_img)
74
+
75
+ return latents_all_list, mask_tensor_list, so_img_list
76
+
77
+
78
+ # Note: need to keep the supervision, especially the box corrdinates, corresponds to each other in single object and overall.
79
+
80
+ def run(
81
+ spec, bg_seed = 1, fg_seed_start = 20, frozen_step_ratio=0.4, gligen_scheduled_sampling_beta = 0.3,
82
+ so_center_box = False, fg_blending_ratio = 0.1, so_horizontal_center_only = True,
83
+ align_with_overall_bboxes = False, horizontal_shift_only = True
84
+ ):
85
+ """
86
+ so_center_box: using centered box in single object generation
87
+ so_horizontal_center_only: move to the center horizontally only
88
+
89
+ align_with_overall_bboxes: Align the center of the mask, latents, and cross-attention with the center of the box in overall bboxes
90
+ horizontal_shift_only: only shift horizontally for the alignment of mask, latents, and cross-attention
91
+ """
92
+
93
+ print("generation:", spec, bg_seed, fg_seed_start, frozen_step_ratio, gligen_scheduled_sampling_beta)
94
+
95
+ frozen_step_ratio = min(max(frozen_step_ratio, 0.), 1.)
96
+ frozen_steps = int(num_inference_steps * frozen_step_ratio)
97
+
98
+ if True:
99
+ so_prompt_phrase_word_box_list, overall_prompt, overall_phrases_words_bboxes = parse.convert_spec(spec, height, width, verbose=verbose)
100
+
101
+ overall_phrases, overall_words, overall_bboxes = [item[0] for item in overall_phrases_words_bboxes], [item[1] for item in overall_phrases_words_bboxes], [item[2] for item in overall_phrases_words_bboxes]
102
+
103
+ # The so box is centered but the overall boxes are not (since we need to place to the right place).
104
+ if so_center_box:
105
+ so_prompt_phrase_word_box_list = [(prompt, phrase, word, utils.get_centered_box(bbox, horizontal_center_only=so_horizontal_center_only)) for prompt, phrase, word, bbox in so_prompt_phrase_word_box_list]
106
+ if verbose:
107
+ print(f"centered so_prompt_phrase_word_box_list: {so_prompt_phrase_word_box_list}")
108
+ so_boxes = [item[-1] for item in so_prompt_phrase_word_box_list]
109
+
110
+ if True:
111
+ so_negative_prompt = "artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image, bad proportions, duplicate, two, many, group, occlusion, occluded, side, border, collate"
112
+ overall_negative_prompt = "artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image, bad proportions, duplicate"
113
+ else:
114
+ so_negative_prompt = ""
115
+ overall_negative_prompt = ""
116
+
117
+ sam_refine_kwargs = dict(
118
+ discourage_mask_below_confidence=discourage_mask_below_confidence, discourage_mask_below_coarse_iou=discourage_mask_below_coarse_iou,
119
+ height=height, width=width, H=H, W=W
120
+ )
121
+
122
+
123
+ # Note that so and overall use different negative prompts
124
+
125
+ so_prompts = [item[0] for item in so_prompt_phrase_word_box_list]
126
+ if so_prompts:
127
+ so_input_embeddings = models.encode_prompts(prompts=so_prompts, tokenizer=tokenizer, text_encoder=text_encoder, negative_prompt=so_negative_prompt, one_uncond_input_only=True)
128
+ else:
129
+ so_input_embeddings = []
130
+
131
+ overall_input_embeddings = models.encode_prompts(prompts=[overall_prompt], tokenizer=tokenizer, negative_prompt=overall_negative_prompt, text_encoder=text_encoder)
132
+
133
+
134
+
135
+
136
+ input_latents_list, latents_bg = latents.get_input_latents_list(
137
+ model_dict, bg_seed=bg_seed, fg_seed_start=fg_seed_start,
138
+ so_boxes=so_boxes, fg_blending_ratio=fg_blending_ratio, height=height, width=width, verbose=False
139
+ )
140
+ latents_all_list, mask_tensor_list, so_img_list = get_masked_latents_all_list(
141
+ so_prompt_phrase_word_box_list, input_latents_list,
142
+ gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta,
143
+ sam_refine_kwargs=sam_refine_kwargs, so_input_embeddings=so_input_embeddings, verbose=verbose
144
+ )
145
+
146
+
147
+
148
+ composed_latents, foreground_indices, offset_list = latents.compose_latents_with_alignment(
149
+ model_dict, latents_all_list, mask_tensor_list, num_inference_steps,
150
+ overall_batch_size, height, width, latents_bg=latents_bg,
151
+ align_with_overall_bboxes=align_with_overall_bboxes, overall_bboxes=overall_bboxes,
152
+ horizontal_shift_only=horizontal_shift_only
153
+ )
154
+
155
+ overall_bboxes_flattened, overall_phrases_flattened = [], []
156
+ for overall_bboxes_item, overall_phrase in zip(overall_bboxes, overall_phrases):
157
+ for overall_bbox in overall_bboxes_item:
158
+ overall_bboxes_flattened.append(overall_bbox)
159
+ overall_phrases_flattened.append(overall_phrase)
160
+
161
+ # Generate with composed latents
162
+
163
+ # Foreground should be frozen
164
+ frozen_mask = foreground_indices != 0
165
+
166
+ regen_latents, images = pipelines.generate_gligen(
167
+ model_dict, composed_latents, overall_input_embeddings, num_inference_steps,
168
+ overall_bboxes_flattened, overall_phrases_flattened, guidance_scale=guidance_scale,
169
+ gligen_scheduled_sampling_beta=gligen_scheduled_sampling_beta,
170
+ frozen_steps=frozen_steps, frozen_mask=frozen_mask
171
+ )
172
+
173
+ print(f"Generation with spatial guidance from input latents and first {frozen_steps} steps frozen (directly from the composed latents input)")
174
+ print("Generation from composed latents (with semantic guidance)")
175
+
176
+ # display(Image.fromarray(images[0]), "img", run_ind)
177
+
178
+ return images[0], so_img_list
179
+
models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .models import *
models/attention.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from diffusers.utils import maybe_allow_in_graph
21
+ from .attention_processor import Attention
22
+ from diffusers.models.embeddings import CombinedTimestepLabelEmbeddings
23
+
24
+ # https://github.com/gligen/diffusers/blob/23a9a0fab1b48752c7b9bcc98f6fe3b1d8fa7990/src/diffusers/models/attention.py
25
+ class GatedSelfAttentionDense(nn.Module):
26
+ def __init__(self, query_dim, context_dim, n_heads, d_head):
27
+ super().__init__()
28
+
29
+ # we need a linear projection since we need cat visual feature and obj feature
30
+ self.linear = nn.Linear(context_dim, query_dim)
31
+
32
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
33
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
34
+
35
+ self.norm1 = nn.LayerNorm(query_dim)
36
+ self.norm2 = nn.LayerNorm(query_dim)
37
+
38
+ self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
39
+ self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
40
+
41
+ self.enabled = True
42
+
43
+ def forward(self, x, objs, fuser_attn_kwargs={}):
44
+ if not self.enabled:
45
+ return x
46
+
47
+ n_visual = x.shape[1]
48
+ objs = self.linear(objs)
49
+
50
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)), **fuser_attn_kwargs)[:, :n_visual, :]
51
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
52
+
53
+ return x
54
+
55
+ @maybe_allow_in_graph
56
+ class BasicTransformerBlock(nn.Module):
57
+ r"""
58
+ A basic Transformer block.
59
+
60
+ Parameters:
61
+ dim (`int`): The number of channels in the input and output.
62
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
63
+ attention_head_dim (`int`): The number of channels in each head.
64
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
65
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
66
+ only_cross_attention (`bool`, *optional*):
67
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
68
+ double_self_attention (`bool`, *optional*):
69
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
70
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
71
+ num_embeds_ada_norm (:
72
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
73
+ attention_bias (:
74
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
75
+ """
76
+
77
+ def __init__(
78
+ self,
79
+ dim: int,
80
+ num_attention_heads: int,
81
+ attention_head_dim: int,
82
+ dropout=0.0,
83
+ cross_attention_dim: Optional[int] = None,
84
+ activation_fn: str = "geglu",
85
+ num_embeds_ada_norm: Optional[int] = None,
86
+ attention_bias: bool = False,
87
+ only_cross_attention: bool = False,
88
+ double_self_attention: bool = False,
89
+ upcast_attention: bool = False,
90
+ norm_elementwise_affine: bool = True,
91
+ norm_type: str = "layer_norm",
92
+ final_dropout: bool = False,
93
+ use_gated_attention: bool = False,
94
+ ):
95
+ super().__init__()
96
+ self.only_cross_attention = only_cross_attention
97
+
98
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
99
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
100
+
101
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
102
+ raise ValueError(
103
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
104
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
105
+ )
106
+
107
+ # Define 3 blocks. Each block has its own normalization layer.
108
+ # 1. Self-Attn
109
+ if self.use_ada_layer_norm:
110
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
111
+ elif self.use_ada_layer_norm_zero:
112
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
113
+ else:
114
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
115
+ self.attn1 = Attention(
116
+ query_dim=dim,
117
+ heads=num_attention_heads,
118
+ dim_head=attention_head_dim,
119
+ dropout=dropout,
120
+ bias=attention_bias,
121
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
122
+ upcast_attention=upcast_attention,
123
+ )
124
+
125
+ # 2. Cross-Attn
126
+ if cross_attention_dim is not None or double_self_attention:
127
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
128
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
129
+ # the second cross attention block.
130
+ self.norm2 = (
131
+ AdaLayerNorm(dim, num_embeds_ada_norm)
132
+ if self.use_ada_layer_norm
133
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
134
+ )
135
+ self.attn2 = Attention(
136
+ query_dim=dim,
137
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
138
+ heads=num_attention_heads,
139
+ dim_head=attention_head_dim,
140
+ dropout=dropout,
141
+ bias=attention_bias,
142
+ upcast_attention=upcast_attention,
143
+ ) # is self-attn if encoder_hidden_states is none
144
+ else:
145
+ self.norm2 = None
146
+ self.attn2 = None
147
+
148
+ # 3. Feed-forward
149
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
150
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
151
+
152
+ # 4. Fuser
153
+ if use_gated_attention:
154
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
155
+
156
+ def forward(
157
+ self,
158
+ hidden_states: torch.FloatTensor,
159
+ attention_mask: Optional[torch.FloatTensor] = None,
160
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
161
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
162
+ timestep: Optional[torch.LongTensor] = None,
163
+ cross_attention_kwargs: Dict[str, Any] = None,
164
+ class_labels: Optional[torch.LongTensor] = None,
165
+ return_cross_attention_probs: bool = None,
166
+ ):
167
+ # Notice that normalization is always applied before the real computation in the following blocks.
168
+
169
+ # 0. Prepare GLIGEN inputs
170
+ if 'gligen' in cross_attention_kwargs:
171
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
172
+ gligen_kwargs = cross_attention_kwargs.pop('gligen', None)
173
+ else:
174
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
175
+ gligen_kwargs = None
176
+
177
+ # 1. Self-Attention
178
+ if self.use_ada_layer_norm:
179
+ norm_hidden_states = self.norm1(hidden_states, timestep)
180
+ elif self.use_ada_layer_norm_zero:
181
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
182
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
183
+ )
184
+ else:
185
+ norm_hidden_states = self.norm1(hidden_states)
186
+
187
+ attn_output = self.attn1(
188
+ norm_hidden_states,
189
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
190
+ attention_mask=attention_mask,
191
+ **cross_attention_kwargs,
192
+ )
193
+ if self.use_ada_layer_norm_zero:
194
+ attn_output = gate_msa.unsqueeze(1) * attn_output
195
+ hidden_states = attn_output + hidden_states
196
+
197
+ # 1.5 GLIGEN Control
198
+ if gligen_kwargs is not None:
199
+ # print(gligen_kwargs)
200
+ hidden_states = self.fuser(hidden_states, gligen_kwargs['objs'], fuser_attn_kwargs=gligen_kwargs.get("fuser_attn_kwargs", {}))
201
+ # 1.5 ends
202
+
203
+ # 2. Cross-Attention
204
+ if self.attn2 is not None:
205
+ norm_hidden_states = (
206
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
207
+ )
208
+
209
+ attn_output = self.attn2(
210
+ norm_hidden_states,
211
+ encoder_hidden_states=encoder_hidden_states,
212
+ attention_mask=encoder_attention_mask,
213
+ return_attntion_probs=return_cross_attention_probs,
214
+ **cross_attention_kwargs,
215
+ )
216
+
217
+ if return_cross_attention_probs:
218
+ attn_output, cross_attention_probs = attn_output
219
+
220
+ hidden_states = attn_output + hidden_states
221
+
222
+ # 3. Feed-forward
223
+ norm_hidden_states = self.norm3(hidden_states)
224
+
225
+ if self.use_ada_layer_norm_zero:
226
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
227
+
228
+ ff_output = self.ff(norm_hidden_states)
229
+
230
+ if self.use_ada_layer_norm_zero:
231
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
232
+
233
+ hidden_states = ff_output + hidden_states
234
+
235
+ if return_cross_attention_probs and self.attn2 is not None:
236
+ return hidden_states, cross_attention_probs
237
+ return hidden_states
238
+
239
+
240
+ class FeedForward(nn.Module):
241
+ r"""
242
+ A feed-forward layer.
243
+
244
+ Parameters:
245
+ dim (`int`): The number of channels in the input.
246
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
247
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
248
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
249
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
250
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
251
+ """
252
+
253
+ def __init__(
254
+ self,
255
+ dim: int,
256
+ dim_out: Optional[int] = None,
257
+ mult: int = 4,
258
+ dropout: float = 0.0,
259
+ activation_fn: str = "geglu",
260
+ final_dropout: bool = False,
261
+ ):
262
+ super().__init__()
263
+ inner_dim = int(dim * mult)
264
+ dim_out = dim_out if dim_out is not None else dim
265
+
266
+ if activation_fn == "gelu":
267
+ act_fn = GELU(dim, inner_dim)
268
+ if activation_fn == "gelu-approximate":
269
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
270
+ elif activation_fn == "geglu":
271
+ act_fn = GEGLU(dim, inner_dim)
272
+ elif activation_fn == "geglu-approximate":
273
+ act_fn = ApproximateGELU(dim, inner_dim)
274
+
275
+ self.net = nn.ModuleList([])
276
+ # project in
277
+ self.net.append(act_fn)
278
+ # project dropout
279
+ self.net.append(nn.Dropout(dropout))
280
+ # project out
281
+ self.net.append(nn.Linear(inner_dim, dim_out))
282
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
283
+ if final_dropout:
284
+ self.net.append(nn.Dropout(dropout))
285
+
286
+ def forward(self, hidden_states):
287
+ for module in self.net:
288
+ hidden_states = module(hidden_states)
289
+ return hidden_states
290
+
291
+
292
+ class GELU(nn.Module):
293
+ r"""
294
+ GELU activation function with tanh approximation support with `approximate="tanh"`.
295
+ """
296
+
297
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
298
+ super().__init__()
299
+ self.proj = nn.Linear(dim_in, dim_out)
300
+ self.approximate = approximate
301
+
302
+ def gelu(self, gate):
303
+ if gate.device.type != "mps":
304
+ return F.gelu(gate, approximate=self.approximate)
305
+ # mps: gelu is not implemented for float16
306
+ return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
307
+
308
+ def forward(self, hidden_states):
309
+ hidden_states = self.proj(hidden_states)
310
+ hidden_states = self.gelu(hidden_states)
311
+ return hidden_states
312
+
313
+
314
+ class GEGLU(nn.Module):
315
+ r"""
316
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
317
+
318
+ Parameters:
319
+ dim_in (`int`): The number of channels in the input.
320
+ dim_out (`int`): The number of channels in the output.
321
+ """
322
+
323
+ def __init__(self, dim_in: int, dim_out: int):
324
+ super().__init__()
325
+ self.proj = nn.Linear(dim_in, dim_out * 2)
326
+
327
+ def gelu(self, gate):
328
+ if gate.device.type != "mps":
329
+ return F.gelu(gate)
330
+ # mps: gelu is not implemented for float16
331
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
332
+
333
+ def forward(self, hidden_states):
334
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
335
+ return hidden_states * self.gelu(gate)
336
+
337
+
338
+ class ApproximateGELU(nn.Module):
339
+ """
340
+ The approximate form of Gaussian Error Linear Unit (GELU)
341
+
342
+ For more details, see section 2: https://arxiv.org/abs/1606.08415
343
+ """
344
+
345
+ def __init__(self, dim_in: int, dim_out: int):
346
+ super().__init__()
347
+ self.proj = nn.Linear(dim_in, dim_out)
348
+
349
+ def forward(self, x):
350
+ x = self.proj(x)
351
+ return x * torch.sigmoid(1.702 * x)
352
+
353
+
354
+ class AdaLayerNorm(nn.Module):
355
+ """
356
+ Norm layer modified to incorporate timestep embeddings.
357
+ """
358
+
359
+ def __init__(self, embedding_dim, num_embeddings):
360
+ super().__init__()
361
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
362
+ self.silu = nn.SiLU()
363
+ self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
364
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
365
+
366
+ def forward(self, x, timestep):
367
+ emb = self.linear(self.silu(self.emb(timestep)))
368
+ scale, shift = torch.chunk(emb, 2)
369
+ x = self.norm(x) * (1 + scale) + shift
370
+ return x
371
+
372
+
373
+ class AdaLayerNormZero(nn.Module):
374
+ """
375
+ Norm layer adaptive layer norm zero (adaLN-Zero).
376
+ """
377
+
378
+ def __init__(self, embedding_dim, num_embeddings):
379
+ super().__init__()
380
+
381
+ self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
382
+
383
+ self.silu = nn.SiLU()
384
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
385
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
386
+
387
+ def forward(self, x, timestep, class_labels, hidden_dtype=None):
388
+ emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
389
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
390
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
391
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
392
+
models/attention_processor.py ADDED
@@ -0,0 +1,508 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import warnings
15
+ from typing import Callable, Optional, Union
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.utils import deprecate, logging, maybe_allow_in_graph
22
+
23
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
24
+
25
+ @maybe_allow_in_graph
26
+ class Attention(nn.Module):
27
+ r"""
28
+ A cross attention layer.
29
+
30
+ Parameters:
31
+ query_dim (`int`): The number of channels in the query.
32
+ cross_attention_dim (`int`, *optional*):
33
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
34
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
35
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
36
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
37
+ bias (`bool`, *optional*, defaults to False):
38
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ query_dim: int,
44
+ cross_attention_dim: Optional[int] = None,
45
+ heads: int = 8,
46
+ dim_head: int = 64,
47
+ dropout: float = 0.0,
48
+ bias=False,
49
+ upcast_attention: bool = False,
50
+ upcast_softmax: bool = False,
51
+ cross_attention_norm: Optional[str] = None,
52
+ cross_attention_norm_num_groups: int = 32,
53
+ added_kv_proj_dim: Optional[int] = None,
54
+ norm_num_groups: Optional[int] = None,
55
+ spatial_norm_dim: Optional[int] = None,
56
+ out_bias: bool = True,
57
+ scale_qk: bool = True,
58
+ only_cross_attention: bool = False,
59
+ eps: float = 1e-5,
60
+ rescale_output_factor: float = 1.0,
61
+ residual_connection: bool = False,
62
+ _from_deprecated_attn_block=False,
63
+ processor: Optional["AttnProcessor"] = None,
64
+ ):
65
+ super().__init__()
66
+ inner_dim = dim_head * heads
67
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
68
+ self.upcast_attention = upcast_attention
69
+ self.upcast_softmax = upcast_softmax
70
+ self.rescale_output_factor = rescale_output_factor
71
+ self.residual_connection = residual_connection
72
+
73
+ # we make use of this private variable to know whether this class is loaded
74
+ # with an deprecated state dict so that we can convert it on the fly
75
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
76
+
77
+ self.scale_qk = scale_qk
78
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
79
+
80
+ self.heads = heads
81
+ # for slice_size > 0 the attention score computation
82
+ # is split across the batch axis to save memory
83
+ # You can set slice_size with `set_attention_slice`
84
+ self.sliceable_head_dim = heads
85
+
86
+ self.added_kv_proj_dim = added_kv_proj_dim
87
+ self.only_cross_attention = only_cross_attention
88
+
89
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
90
+ raise ValueError(
91
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
92
+ )
93
+
94
+ if norm_num_groups is not None:
95
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
96
+ else:
97
+ self.group_norm = None
98
+
99
+ if spatial_norm_dim is not None:
100
+ self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
101
+ else:
102
+ self.spatial_norm = None
103
+
104
+ if cross_attention_norm is None:
105
+ self.norm_cross = None
106
+ elif cross_attention_norm == "layer_norm":
107
+ self.norm_cross = nn.LayerNorm(cross_attention_dim)
108
+ elif cross_attention_norm == "group_norm":
109
+ if self.added_kv_proj_dim is not None:
110
+ # The given `encoder_hidden_states` are initially of shape
111
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
112
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
113
+ # before the projection, so we need to use `added_kv_proj_dim` as
114
+ # the number of channels for the group norm.
115
+ norm_cross_num_channels = added_kv_proj_dim
116
+ else:
117
+ norm_cross_num_channels = cross_attention_dim
118
+
119
+ self.norm_cross = nn.GroupNorm(
120
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
121
+ )
122
+ else:
123
+ raise ValueError(
124
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
125
+ )
126
+
127
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
128
+
129
+ if not self.only_cross_attention:
130
+ # only relevant for the `AddedKVProcessor` classes
131
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
132
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
133
+ else:
134
+ self.to_k = None
135
+ self.to_v = None
136
+
137
+ if self.added_kv_proj_dim is not None:
138
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
139
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)
140
+
141
+ self.to_out = nn.ModuleList([])
142
+ self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
143
+ self.to_out.append(nn.Dropout(dropout))
144
+
145
+ # set attention processor
146
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
147
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
148
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
149
+ if processor is None:
150
+ # processor = (
151
+ # AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
152
+ # )
153
+ # Note: efficient attention is not used. We can use efficient attention to speed up.
154
+ processor = AttnProcessor()
155
+ self.set_processor(processor)
156
+
157
+ def set_processor(self, processor: "AttnProcessor"):
158
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
159
+ # pop `processor` from `self._modules`
160
+ if (
161
+ hasattr(self, "processor")
162
+ and isinstance(self.processor, torch.nn.Module)
163
+ and not isinstance(processor, torch.nn.Module)
164
+ ):
165
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
166
+ self._modules.pop("processor")
167
+
168
+ self.processor = processor
169
+
170
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, return_attntion_probs=False, **cross_attention_kwargs):
171
+ # The `Attention` class can call different attention processors / attention functions
172
+ # here we simply pass along all tensors to the selected processor class
173
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
174
+ return self.processor(
175
+ self,
176
+ hidden_states,
177
+ encoder_hidden_states=encoder_hidden_states,
178
+ attention_mask=attention_mask,
179
+ return_attntion_probs=return_attntion_probs,
180
+ **cross_attention_kwargs,
181
+ )
182
+
183
+ def batch_to_head_dim(self, tensor):
184
+ head_size = self.heads
185
+ batch_size, seq_len, dim = tensor.shape
186
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
187
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
188
+ return tensor
189
+
190
+ def head_to_batch_dim(self, tensor, out_dim=3):
191
+ head_size = self.heads
192
+ batch_size, seq_len, dim = tensor.shape
193
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
194
+ tensor = tensor.permute(0, 2, 1, 3)
195
+
196
+ if out_dim == 3:
197
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
198
+
199
+ return tensor
200
+
201
+ def get_attention_scores(self, query, key, attention_mask=None):
202
+ dtype = query.dtype
203
+ if self.upcast_attention:
204
+ query = query.float()
205
+ key = key.float()
206
+
207
+ if attention_mask is None:
208
+ baddbmm_input = torch.empty(
209
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
210
+ )
211
+ beta = 0
212
+ else:
213
+ baddbmm_input = attention_mask
214
+ beta = 1
215
+
216
+ attention_scores = torch.baddbmm(
217
+ baddbmm_input,
218
+ query,
219
+ key.transpose(-1, -2),
220
+ beta=beta,
221
+ alpha=self.scale,
222
+ )
223
+ del baddbmm_input
224
+
225
+ if self.upcast_softmax:
226
+ attention_scores = attention_scores.float()
227
+
228
+ attention_probs = attention_scores.softmax(dim=-1)
229
+ del attention_scores
230
+
231
+ attention_probs = attention_probs.to(dtype)
232
+
233
+ return attention_probs
234
+
235
+ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
236
+ if batch_size is None:
237
+ deprecate(
238
+ "batch_size=None",
239
+ "0.0.15",
240
+ (
241
+ "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
242
+ " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
243
+ " `prepare_attention_mask` when preparing the attention_mask."
244
+ ),
245
+ )
246
+ batch_size = 1
247
+
248
+ head_size = self.heads
249
+ if attention_mask is None:
250
+ return attention_mask
251
+
252
+ current_length: int = attention_mask.shape[-1]
253
+ if current_length != target_length:
254
+ if attention_mask.device.type == "mps":
255
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
256
+ # Instead, we can manually construct the padding tensor.
257
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
258
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
259
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
260
+ else:
261
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
262
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
263
+ # remaining_length: int = target_length - current_length
264
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
265
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
266
+
267
+ if out_dim == 3:
268
+ if attention_mask.shape[0] < batch_size * head_size:
269
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
270
+ elif out_dim == 4:
271
+ attention_mask = attention_mask.unsqueeze(1)
272
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
273
+
274
+ return attention_mask
275
+
276
+ def norm_encoder_hidden_states(self, encoder_hidden_states):
277
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
278
+
279
+ if isinstance(self.norm_cross, nn.LayerNorm):
280
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
281
+ elif isinstance(self.norm_cross, nn.GroupNorm):
282
+ # Group norm norms along the channels dimension and expects
283
+ # input to be in the shape of (N, C, *). In this case, we want
284
+ # to norm along the hidden dimension, so we need to move
285
+ # (batch_size, sequence_length, hidden_size) ->
286
+ # (batch_size, hidden_size, sequence_length)
287
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
288
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
289
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
290
+ else:
291
+ assert False
292
+
293
+ return encoder_hidden_states
294
+
295
+
296
+ class AttnProcessor:
297
+ r"""
298
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
299
+ """
300
+
301
+ def __init__(self):
302
+ if not hasattr(F, "scaled_dot_product_attention"):
303
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
304
+
305
+ def __call_fast__(
306
+ self,
307
+ attn: Attention,
308
+ hidden_states,
309
+ encoder_hidden_states=None,
310
+ attention_mask=None,
311
+ temb=None,
312
+ ):
313
+ residual = hidden_states
314
+
315
+ if attn.spatial_norm is not None:
316
+ hidden_states = attn.spatial_norm(hidden_states, temb)
317
+
318
+ input_ndim = hidden_states.ndim
319
+
320
+ if input_ndim == 4:
321
+ batch_size, channel, height, width = hidden_states.shape
322
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
323
+
324
+ batch_size, sequence_length, _ = (
325
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
326
+ )
327
+ inner_dim = hidden_states.shape[-1]
328
+
329
+ if attention_mask is not None:
330
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
331
+ # scaled_dot_product_attention expects attention_mask shape to be
332
+ # (batch, heads, source_length, target_length)
333
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
334
+
335
+ if attn.group_norm is not None:
336
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
337
+
338
+ query = attn.to_q(hidden_states)
339
+
340
+ if encoder_hidden_states is None:
341
+ encoder_hidden_states = hidden_states
342
+ elif attn.norm_cross:
343
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
344
+
345
+ key = attn.to_k(encoder_hidden_states)
346
+ value = attn.to_v(encoder_hidden_states)
347
+
348
+ head_dim = inner_dim // attn.heads
349
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
350
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
351
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
352
+
353
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
354
+ # TODO: add support for attn.scale when we move to Torch 2.1
355
+ hidden_states = F.scaled_dot_product_attention(
356
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
357
+ )
358
+
359
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
360
+ hidden_states = hidden_states.to(query.dtype)
361
+
362
+ # linear proj
363
+ hidden_states = attn.to_out[0](hidden_states)
364
+ # dropout
365
+ hidden_states = attn.to_out[1](hidden_states)
366
+
367
+ if input_ndim == 4:
368
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
369
+
370
+ if attn.residual_connection:
371
+ hidden_states = hidden_states + residual
372
+
373
+ hidden_states = hidden_states / attn.rescale_output_factor
374
+
375
+ return hidden_states
376
+
377
+ def __call__(
378
+ self,
379
+ attn: Attention,
380
+ hidden_states,
381
+ encoder_hidden_states=None,
382
+ attention_mask=None,
383
+ temb=None,
384
+ return_attntion_probs=False,
385
+ attn_key=None,
386
+ attn_process_fn=None,
387
+ return_cond_ca_only=False,
388
+ return_token_ca_only=None,
389
+ offload_cross_attn_to_cpu=False,
390
+ save_attn_to_dict=None,
391
+ save_keys=None,
392
+ enable_flash_attn=True,
393
+ ):
394
+ """
395
+ attn_key: current key (a tuple of hierarchy index (up/mid/down, stage id, block id, sub-block id), sub block id should always be 0 in SD UNet)
396
+ save_attn_to_dict: pass in a dict to save to dict
397
+ """
398
+ cross_attn = encoder_hidden_states is not None
399
+
400
+ if (not cross_attn) or (
401
+ (attn_process_fn is None)
402
+ and not (save_attn_to_dict is not None and (save_keys is None or (tuple(attn_key) in save_keys)))
403
+ and not return_attntion_probs):
404
+ with torch.backends.cuda.sdp_kernel(enable_flash=enable_flash_attn, enable_math=True, enable_mem_efficient=enable_flash_attn):
405
+ return self.__call_fast__(attn, hidden_states, encoder_hidden_states, attention_mask, temb)
406
+
407
+ residual = hidden_states
408
+
409
+ if attn.spatial_norm is not None:
410
+ hidden_states = attn.spatial_norm(hidden_states, temb)
411
+
412
+ input_ndim = hidden_states.ndim
413
+
414
+ if input_ndim == 4:
415
+ batch_size, channel, height, width = hidden_states.shape
416
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
417
+
418
+ batch_size, sequence_length, _ = (
419
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
420
+ )
421
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
422
+
423
+ if attn.group_norm is not None:
424
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
425
+
426
+ query = attn.to_q(hidden_states)
427
+
428
+ if encoder_hidden_states is None:
429
+ encoder_hidden_states = hidden_states
430
+ elif attn.norm_cross:
431
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
432
+
433
+ key = attn.to_k(encoder_hidden_states)
434
+ value = attn.to_v(encoder_hidden_states)
435
+
436
+ query = attn.head_to_batch_dim(query)
437
+ key = attn.head_to_batch_dim(key)
438
+ value = attn.head_to_batch_dim(value)
439
+
440
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
441
+ # Currently only process cross-attention
442
+ if attn_process_fn is not None and cross_attn:
443
+ attention_probs_before_process = attention_probs.clone()
444
+ attention_probs = attn_process_fn(attention_probs, query, key, value, attn_key=attn_key, cross_attn=cross_attn, batch_size=batch_size, heads=attn.heads)
445
+ else:
446
+ attention_probs_before_process = attention_probs
447
+ hidden_states = torch.bmm(attention_probs, value)
448
+ hidden_states = attn.batch_to_head_dim(hidden_states)
449
+
450
+ # linear proj
451
+ hidden_states = attn.to_out[0](hidden_states)
452
+ # dropout
453
+ hidden_states = attn.to_out[1](hidden_states)
454
+
455
+ if input_ndim == 4:
456
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
457
+
458
+ if attn.residual_connection:
459
+ hidden_states = hidden_states + residual
460
+
461
+ hidden_states = hidden_states / attn.rescale_output_factor
462
+
463
+ if return_attntion_probs or save_attn_to_dict is not None:
464
+ # Recover batch dimension: (batch_size, heads, flattened_2d, text_tokens)
465
+ attention_probs_unflattened = attention_probs_before_process.unflatten(dim=0, sizes=(batch_size, attn.heads))
466
+ if return_token_ca_only is not None:
467
+ # (batch size, n heads, 2d dimension, num text tokens)
468
+ if isinstance(return_token_ca_only, int):
469
+ # return_token_ca_only: an integer
470
+ attention_probs_unflattened = attention_probs_unflattened[:, :, :, return_token_ca_only:return_token_ca_only+1]
471
+ else:
472
+ # return_token_ca_only: A 1d index tensor
473
+ attention_probs_unflattened = attention_probs_unflattened[:, :, :, return_token_ca_only]
474
+ if return_cond_ca_only:
475
+ assert batch_size % 2 == 0, f"Samples are not in pairs: {batch_size} samples"
476
+ attention_probs_unflattened = attention_probs_unflattened[batch_size // 2:]
477
+ if offload_cross_attn_to_cpu:
478
+ attention_probs_unflattened = attention_probs_unflattened.cpu()
479
+ if save_attn_to_dict is not None and (save_keys is None or (tuple(attn_key) in save_keys)):
480
+ save_attn_to_dict[tuple(attn_key)] = attention_probs_unflattened
481
+ if return_attntion_probs:
482
+ return hidden_states, attention_probs_unflattened
483
+ return hidden_states
484
+
485
+ # For typing
486
+ AttentionProcessor = AttnProcessor
487
+
488
+ class SpatialNorm(nn.Module):
489
+ """
490
+ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002
491
+ """
492
+
493
+ def __init__(
494
+ self,
495
+ f_channels,
496
+ zq_channels,
497
+ ):
498
+ super().__init__()
499
+ self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
500
+ self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
501
+ self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
502
+
503
+ def forward(self, f, zq):
504
+ f_size = f.shape[-2:]
505
+ zq = F.interpolate(zq, size=f_size, mode="nearest")
506
+ norm_f = self.norm_layer(f)
507
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
508
+ return new_f
models/modeling_utils.py ADDED
@@ -0,0 +1,874 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import inspect
18
+ import itertools
19
+ import os
20
+ from functools import partial
21
+ from typing import Any, Callable, List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ from torch import Tensor, device
25
+
26
+ from diffusers import __version__
27
+ from diffusers.utils import (
28
+ CONFIG_NAME,
29
+ DIFFUSERS_CACHE,
30
+ FLAX_WEIGHTS_NAME,
31
+ HF_HUB_OFFLINE,
32
+ SAFETENSORS_WEIGHTS_NAME,
33
+ WEIGHTS_NAME,
34
+ _add_variant,
35
+ _get_model_file,
36
+ deprecate,
37
+ is_accelerate_available,
38
+ is_safetensors_available,
39
+ is_torch_version,
40
+ logging,
41
+ )
42
+
43
+
44
+ logger = logging.get_logger(__name__)
45
+
46
+
47
+ if is_torch_version(">=", "1.9.0"):
48
+ _LOW_CPU_MEM_USAGE_DEFAULT = True
49
+ else:
50
+ _LOW_CPU_MEM_USAGE_DEFAULT = False
51
+
52
+
53
+ if is_accelerate_available():
54
+ import accelerate
55
+ from accelerate.utils import set_module_tensor_to_device
56
+ from accelerate.utils.versions import is_torch_version
57
+
58
+ if is_safetensors_available():
59
+ import safetensors
60
+
61
+
62
+ def get_parameter_device(parameter: torch.nn.Module):
63
+ try:
64
+ parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
65
+ return next(parameters_and_buffers).device
66
+ except StopIteration:
67
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
68
+
69
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
70
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
71
+ return tuples
72
+
73
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
74
+ first_tuple = next(gen)
75
+ return first_tuple[1].device
76
+
77
+
78
+ def get_parameter_dtype(parameter: torch.nn.Module):
79
+ try:
80
+ params = tuple(parameter.parameters())
81
+ if len(params) > 0:
82
+ return params[0].dtype
83
+
84
+ buffers = tuple(parameter.buffers())
85
+ if len(buffers) > 0:
86
+ return buffers[0].dtype
87
+
88
+ except StopIteration:
89
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
90
+
91
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
92
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
93
+ return tuples
94
+
95
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
96
+ first_tuple = next(gen)
97
+ return first_tuple[1].dtype
98
+
99
+
100
+ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
101
+ """
102
+ Reads a checkpoint file, returning properly formatted errors if they arise.
103
+ """
104
+ try:
105
+ if os.path.basename(checkpoint_file) == _add_variant(WEIGHTS_NAME, variant):
106
+ return torch.load(checkpoint_file, map_location="cpu")
107
+ else:
108
+ return safetensors.torch.load_file(checkpoint_file, device="cpu")
109
+ except Exception as e:
110
+ try:
111
+ with open(checkpoint_file) as f:
112
+ if f.read().startswith("version"):
113
+ raise OSError(
114
+ "You seem to have cloned a repository without having git-lfs installed. Please install "
115
+ "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
116
+ "you cloned."
117
+ )
118
+ else:
119
+ raise ValueError(
120
+ f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
121
+ "model. Make sure you have saved the model properly."
122
+ ) from e
123
+ except (UnicodeDecodeError, ValueError):
124
+ raise OSError(
125
+ f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
126
+ f"at '{checkpoint_file}'. "
127
+ "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
128
+ )
129
+
130
+
131
+ def _load_state_dict_into_model(model_to_load, state_dict):
132
+ # Convert old format to new format if needed from a PyTorch state_dict
133
+ # copy state_dict so _load_from_state_dict can modify it
134
+ state_dict = state_dict.copy()
135
+ error_msgs = []
136
+
137
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
138
+ # so we need to apply the function recursively.
139
+ def load(module: torch.nn.Module, prefix=""):
140
+ args = (state_dict, prefix, {}, True, [], [], error_msgs)
141
+ module._load_from_state_dict(*args)
142
+
143
+ for name, child in module._modules.items():
144
+ if child is not None:
145
+ load(child, prefix + name + ".")
146
+
147
+ load(model_to_load)
148
+
149
+ return error_msgs
150
+
151
+
152
+ class ModelMixin(torch.nn.Module):
153
+ r"""
154
+ Base class for all models.
155
+
156
+ [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
157
+ and saving models.
158
+
159
+ - **config_name** ([`str`]) -- A filename under which the model should be stored when calling
160
+ [`~models.ModelMixin.save_pretrained`].
161
+ """
162
+ config_name = CONFIG_NAME
163
+ _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
164
+ _supports_gradient_checkpointing = False
165
+
166
+ def __init__(self):
167
+ super().__init__()
168
+
169
+ def __getattr__(self, name: str) -> Any:
170
+ """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
171
+ config attributes directly. See https://github.com/huggingface/diffusers/pull/3129 We need to overwrite
172
+ __getattr__ here in addition so that we don't trigger `torch.nn.Module`'s __getattr__':
173
+ https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
174
+ """
175
+
176
+ is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
177
+ is_attribute = name in self.__dict__
178
+
179
+ if is_in_config and not is_attribute:
180
+ deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'unet.config.{name}'."
181
+ deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False, stacklevel=3)
182
+ return self._internal_dict[name]
183
+
184
+ # call PyTorch's https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
185
+ return super().__getattr__(name)
186
+
187
+ @property
188
+ def is_gradient_checkpointing(self) -> bool:
189
+ """
190
+ Whether gradient checkpointing is activated for this model or not.
191
+
192
+ Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
193
+ activations".
194
+ """
195
+ return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
196
+
197
+ def enable_gradient_checkpointing(self):
198
+ """
199
+ Activates gradient checkpointing for the current model.
200
+
201
+ Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
202
+ activations".
203
+ """
204
+ if not self._supports_gradient_checkpointing:
205
+ raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
206
+ self.apply(partial(self._set_gradient_checkpointing, value=True))
207
+
208
+ def disable_gradient_checkpointing(self):
209
+ """
210
+ Deactivates gradient checkpointing for the current model.
211
+
212
+ Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
213
+ activations".
214
+ """
215
+ if self._supports_gradient_checkpointing:
216
+ self.apply(partial(self._set_gradient_checkpointing, value=False))
217
+
218
+ def set_use_memory_efficient_attention_xformers(
219
+ self, valid: bool, attention_op: Optional[Callable] = None
220
+ ) -> None:
221
+ # Recursively walk through all the children.
222
+ # Any children which exposes the set_use_memory_efficient_attention_xformers method
223
+ # gets the message
224
+ def fn_recursive_set_mem_eff(module: torch.nn.Module):
225
+ if hasattr(module, "set_use_memory_efficient_attention_xformers"):
226
+ module.set_use_memory_efficient_attention_xformers(valid, attention_op)
227
+
228
+ for child in module.children():
229
+ fn_recursive_set_mem_eff(child)
230
+
231
+ for module in self.children():
232
+ if isinstance(module, torch.nn.Module):
233
+ fn_recursive_set_mem_eff(module)
234
+
235
+ def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
236
+ r"""
237
+ Enable memory efficient attention as implemented in xformers.
238
+
239
+ When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
240
+ time. Speed up at training time is not guaranteed.
241
+
242
+ Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
243
+ is used.
244
+
245
+ Parameters:
246
+ attention_op (`Callable`, *optional*):
247
+ Override the default `None` operator for use as `op` argument to the
248
+ [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
249
+ function of xFormers.
250
+
251
+ Examples:
252
+
253
+ ```py
254
+ >>> import torch
255
+ >>> from diffusers import UNet2DConditionModel
256
+ >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
257
+
258
+ >>> model = UNet2DConditionModel.from_pretrained(
259
+ ... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
260
+ ... )
261
+ >>> model = model.to("cuda")
262
+ >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
263
+ ```
264
+ """
265
+ self.set_use_memory_efficient_attention_xformers(True, attention_op)
266
+
267
+ def disable_xformers_memory_efficient_attention(self):
268
+ r"""
269
+ Disable memory efficient attention as implemented in xformers.
270
+ """
271
+ self.set_use_memory_efficient_attention_xformers(False)
272
+
273
+ def save_pretrained(
274
+ self,
275
+ save_directory: Union[str, os.PathLike],
276
+ is_main_process: bool = True,
277
+ save_function: Callable = None,
278
+ safe_serialization: bool = False,
279
+ variant: Optional[str] = None,
280
+ ):
281
+ """
282
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
283
+ `[`~models.ModelMixin.from_pretrained`]` class method.
284
+
285
+ Arguments:
286
+ save_directory (`str` or `os.PathLike`):
287
+ Directory to which to save. Will be created if it doesn't exist.
288
+ is_main_process (`bool`, *optional*, defaults to `True`):
289
+ Whether the process calling this is the main process or not. Useful when in distributed training like
290
+ TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
291
+ the main process to avoid race conditions.
292
+ save_function (`Callable`):
293
+ The function to use to save the state dictionary. Useful on distributed training like TPUs when one
294
+ need to replace `torch.save` by another method. Can be configured with the environment variable
295
+ `DIFFUSERS_SAVE_MODE`.
296
+ safe_serialization (`bool`, *optional*, defaults to `False`):
297
+ Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
298
+ variant (`str`, *optional*):
299
+ If specified, weights are saved in the format pytorch_model.<variant>.bin.
300
+ """
301
+ if safe_serialization and not is_safetensors_available():
302
+ raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
303
+
304
+ if os.path.isfile(save_directory):
305
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
306
+ return
307
+
308
+ os.makedirs(save_directory, exist_ok=True)
309
+
310
+ model_to_save = self
311
+
312
+ # Attach architecture to the config
313
+ # Save the config
314
+ if is_main_process:
315
+ model_to_save.save_config(save_directory)
316
+
317
+ # Save the model
318
+ state_dict = model_to_save.state_dict()
319
+
320
+ weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
321
+ weights_name = _add_variant(weights_name, variant)
322
+
323
+ # Save the model
324
+ if safe_serialization:
325
+ safetensors.torch.save_file(
326
+ state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"}
327
+ )
328
+ else:
329
+ torch.save(state_dict, os.path.join(save_directory, weights_name))
330
+
331
+ logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
332
+
333
+ @classmethod
334
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
335
+ r"""
336
+ Instantiate a pretrained pytorch model from a pre-trained model configuration.
337
+
338
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
339
+ the model, you should first set it back in training mode with `model.train()`.
340
+
341
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
342
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
343
+ task.
344
+
345
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
346
+ weights are discarded.
347
+
348
+ Parameters:
349
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
350
+ Can be either:
351
+
352
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
353
+ Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
354
+ - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
355
+ `./my_model_directory/`.
356
+
357
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
358
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
359
+ standard cache should not be used.
360
+ torch_dtype (`str` or `torch.dtype`, *optional*):
361
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
362
+ will be automatically derived from the model's weights.
363
+ force_download (`bool`, *optional*, defaults to `False`):
364
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
365
+ cached versions if they exist.
366
+ resume_download (`bool`, *optional*, defaults to `False`):
367
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
368
+ file exists.
369
+ proxies (`Dict[str, str]`, *optional*):
370
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
371
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
372
+ output_loading_info(`bool`, *optional*, defaults to `False`):
373
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
374
+ local_files_only(`bool`, *optional*, defaults to `False`):
375
+ Whether or not to only look at local files (i.e., do not try to download the model).
376
+ use_auth_token (`str` or *bool*, *optional*):
377
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
378
+ when running `diffusers-cli login` (stored in `~/.huggingface`).
379
+ revision (`str`, *optional*, defaults to `"main"`):
380
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
381
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
382
+ identifier allowed by git.
383
+ from_flax (`bool`, *optional*, defaults to `False`):
384
+ Load the model weights from a Flax checkpoint save file.
385
+ subfolder (`str`, *optional*, defaults to `""`):
386
+ In case the relevant files are located inside a subfolder of the model repo (either remote in
387
+ huggingface.co or downloaded locally), you can specify the folder name here.
388
+
389
+ mirror (`str`, *optional*):
390
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
391
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
392
+ Please refer to the mirror site for more information.
393
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
394
+ A map that specifies where each submodule should go. It doesn't need to be refined to each
395
+ parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
396
+ same device.
397
+
398
+ To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
399
+ more information about each option see [designing a device
400
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
401
+ max_memory (`Dict`, *optional*):
402
+ A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
403
+ GPU and the available CPU RAM if unset.
404
+ offload_folder (`str` or `os.PathLike`, *optional*):
405
+ If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
406
+ offload_state_dict (`bool`, *optional*):
407
+ If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU
408
+ RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to
409
+ `True` when there is some disk offload.
410
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
411
+ Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
412
+ also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
413
+ model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
414
+ setting this argument to `True` will raise an error.
415
+ variant (`str`, *optional*):
416
+ If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
417
+ ignored when using `from_flax`.
418
+ use_safetensors (`bool`, *optional*, defaults to `None`):
419
+ If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
420
+ `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
421
+ `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
422
+
423
+ <Tip>
424
+
425
+ It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
426
+ models](https://huggingface.co/docs/hub/models-gated#gated-models).
427
+
428
+ </Tip>
429
+
430
+ <Tip>
431
+
432
+ Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
433
+ this method in a firewalled environment.
434
+
435
+ </Tip>
436
+
437
+ """
438
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
439
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
440
+ force_download = kwargs.pop("force_download", False)
441
+ from_flax = kwargs.pop("from_flax", False)
442
+ resume_download = kwargs.pop("resume_download", False)
443
+ proxies = kwargs.pop("proxies", None)
444
+ output_loading_info = kwargs.pop("output_loading_info", False)
445
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
446
+ use_auth_token = kwargs.pop("use_auth_token", None)
447
+ revision = kwargs.pop("revision", None)
448
+ torch_dtype = kwargs.pop("torch_dtype", None)
449
+ subfolder = kwargs.pop("subfolder", None)
450
+ device_map = kwargs.pop("device_map", None)
451
+ max_memory = kwargs.pop("max_memory", None)
452
+ offload_folder = kwargs.pop("offload_folder", None)
453
+ offload_state_dict = kwargs.pop("offload_state_dict", False)
454
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
455
+ variant = kwargs.pop("variant", None)
456
+ use_safetensors = kwargs.pop("use_safetensors", None)
457
+
458
+ if use_safetensors and not is_safetensors_available():
459
+ raise ValueError(
460
+ "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
461
+ )
462
+
463
+ allow_pickle = False
464
+ if use_safetensors is None:
465
+ use_safetensors = is_safetensors_available()
466
+ allow_pickle = True
467
+
468
+ if low_cpu_mem_usage and not is_accelerate_available():
469
+ low_cpu_mem_usage = False
470
+ logger.warning(
471
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
472
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
473
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
474
+ " install accelerate\n```\n."
475
+ )
476
+
477
+ if device_map is not None and not is_accelerate_available():
478
+ raise NotImplementedError(
479
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
480
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
481
+ )
482
+
483
+ # Check if we can handle device_map and dispatching the weights
484
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
485
+ raise NotImplementedError(
486
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
487
+ " `device_map=None`."
488
+ )
489
+
490
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
491
+ raise NotImplementedError(
492
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
493
+ " `low_cpu_mem_usage=False`."
494
+ )
495
+
496
+ if low_cpu_mem_usage is False and device_map is not None:
497
+ raise ValueError(
498
+ f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
499
+ " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
500
+ )
501
+
502
+ # Load config if we don't provide a configuration
503
+ config_path = pretrained_model_name_or_path
504
+
505
+ user_agent = {
506
+ "diffusers": __version__,
507
+ "file_type": "model",
508
+ "framework": "pytorch",
509
+ }
510
+
511
+ # load config
512
+ config, unused_kwargs, commit_hash = cls.load_config(
513
+ config_path,
514
+ cache_dir=cache_dir,
515
+ return_unused_kwargs=True,
516
+ return_commit_hash=True,
517
+ force_download=force_download,
518
+ resume_download=resume_download,
519
+ proxies=proxies,
520
+ local_files_only=local_files_only,
521
+ use_auth_token=use_auth_token,
522
+ revision=revision,
523
+ subfolder=subfolder,
524
+ device_map=device_map,
525
+ max_memory=max_memory,
526
+ offload_folder=offload_folder,
527
+ offload_state_dict=offload_state_dict,
528
+ user_agent=user_agent,
529
+ **kwargs,
530
+ )
531
+
532
+ # load model
533
+ model_file = None
534
+ if from_flax:
535
+ model_file = _get_model_file(
536
+ pretrained_model_name_or_path,
537
+ weights_name=FLAX_WEIGHTS_NAME,
538
+ cache_dir=cache_dir,
539
+ force_download=force_download,
540
+ resume_download=resume_download,
541
+ proxies=proxies,
542
+ local_files_only=local_files_only,
543
+ use_auth_token=use_auth_token,
544
+ revision=revision,
545
+ subfolder=subfolder,
546
+ user_agent=user_agent,
547
+ commit_hash=commit_hash,
548
+ )
549
+ model = cls.from_config(config, **unused_kwargs)
550
+
551
+ # Convert the weights
552
+ from diffusers.models.modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
553
+
554
+ model = load_flax_checkpoint_in_pytorch_model(model, model_file)
555
+ else:
556
+ if use_safetensors:
557
+ try:
558
+ model_file = _get_model_file(
559
+ pretrained_model_name_or_path,
560
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
561
+ cache_dir=cache_dir,
562
+ force_download=force_download,
563
+ resume_download=resume_download,
564
+ proxies=proxies,
565
+ local_files_only=local_files_only,
566
+ use_auth_token=use_auth_token,
567
+ revision=revision,
568
+ subfolder=subfolder,
569
+ user_agent=user_agent,
570
+ commit_hash=commit_hash,
571
+ )
572
+ except IOError as e:
573
+ if not allow_pickle:
574
+ raise e
575
+ pass
576
+ if model_file is None:
577
+ model_file = _get_model_file(
578
+ pretrained_model_name_or_path,
579
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
580
+ cache_dir=cache_dir,
581
+ force_download=force_download,
582
+ resume_download=resume_download,
583
+ proxies=proxies,
584
+ local_files_only=local_files_only,
585
+ use_auth_token=use_auth_token,
586
+ revision=revision,
587
+ subfolder=subfolder,
588
+ user_agent=user_agent,
589
+ commit_hash=commit_hash,
590
+ )
591
+
592
+ if low_cpu_mem_usage:
593
+ # Instantiate model with empty weights
594
+ with accelerate.init_empty_weights():
595
+ model = cls.from_config(config, **unused_kwargs)
596
+
597
+ # if device_map is None, load the state dict and move the params from meta device to the cpu
598
+ if device_map is None:
599
+ param_device = "cpu"
600
+ state_dict = load_state_dict(model_file, variant=variant)
601
+ model._convert_deprecated_attention_blocks(state_dict)
602
+ # move the params from meta device to cpu
603
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
604
+ if len(missing_keys) > 0:
605
+ raise ValueError(
606
+ f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
607
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
608
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
609
+ " those weights or else make sure your checkpoint file is correct."
610
+ )
611
+
612
+ empty_state_dict = model.state_dict()
613
+ for param_name, param in state_dict.items():
614
+ accepts_dtype = "dtype" in set(
615
+ inspect.signature(set_module_tensor_to_device).parameters.keys()
616
+ )
617
+
618
+ if empty_state_dict[param_name].shape != param.shape:
619
+ raise ValueError(
620
+ f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
621
+ )
622
+
623
+ if accepts_dtype:
624
+ set_module_tensor_to_device(
625
+ model, param_name, param_device, value=param, dtype=torch_dtype
626
+ )
627
+ else:
628
+ set_module_tensor_to_device(model, param_name, param_device, value=param)
629
+ else: # else let accelerate handle loading and dispatching.
630
+ # Load weights and dispatch according to the device_map
631
+ # by default the device_map is None and the weights are loaded on the CPU
632
+ accelerate.load_checkpoint_and_dispatch(
633
+ model,
634
+ model_file,
635
+ device_map,
636
+ max_memory=max_memory,
637
+ offload_folder=offload_folder,
638
+ offload_state_dict=offload_state_dict,
639
+ dtype=torch_dtype,
640
+ )
641
+
642
+ loading_info = {
643
+ "missing_keys": [],
644
+ "unexpected_keys": [],
645
+ "mismatched_keys": [],
646
+ "error_msgs": [],
647
+ }
648
+ else:
649
+ model = cls.from_config(config, **unused_kwargs)
650
+
651
+ state_dict = load_state_dict(model_file, variant=variant)
652
+ model._convert_deprecated_attention_blocks(state_dict)
653
+
654
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
655
+ model,
656
+ state_dict,
657
+ model_file,
658
+ pretrained_model_name_or_path,
659
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
660
+ )
661
+
662
+ loading_info = {
663
+ "missing_keys": missing_keys,
664
+ "unexpected_keys": unexpected_keys,
665
+ "mismatched_keys": mismatched_keys,
666
+ "error_msgs": error_msgs,
667
+ }
668
+
669
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
670
+ raise ValueError(
671
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
672
+ )
673
+ elif torch_dtype is not None:
674
+ model = model.to(torch_dtype)
675
+
676
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
677
+
678
+ # Set model in evaluation mode to deactivate DropOut modules by default
679
+ model.eval()
680
+ if output_loading_info:
681
+ return model, loading_info
682
+
683
+ return model
684
+
685
+ @classmethod
686
+ def _load_pretrained_model(
687
+ cls,
688
+ model,
689
+ state_dict,
690
+ resolved_archive_file,
691
+ pretrained_model_name_or_path,
692
+ ignore_mismatched_sizes=False,
693
+ ):
694
+ # Retrieve missing & unexpected_keys
695
+ model_state_dict = model.state_dict()
696
+ loaded_keys = list(state_dict.keys())
697
+
698
+ expected_keys = list(model_state_dict.keys())
699
+
700
+ original_loaded_keys = loaded_keys
701
+
702
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
703
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
704
+
705
+ # Make sure we are able to load base models as well as derived models (with heads)
706
+ model_to_load = model
707
+
708
+ def _find_mismatched_keys(
709
+ state_dict,
710
+ model_state_dict,
711
+ loaded_keys,
712
+ ignore_mismatched_sizes,
713
+ ):
714
+ mismatched_keys = []
715
+ if ignore_mismatched_sizes:
716
+ for checkpoint_key in loaded_keys:
717
+ model_key = checkpoint_key
718
+
719
+ if (
720
+ model_key in model_state_dict
721
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
722
+ ):
723
+ mismatched_keys.append(
724
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
725
+ )
726
+ del state_dict[checkpoint_key]
727
+ return mismatched_keys
728
+
729
+ if state_dict is not None:
730
+ # Whole checkpoint
731
+ mismatched_keys = _find_mismatched_keys(
732
+ state_dict,
733
+ model_state_dict,
734
+ original_loaded_keys,
735
+ ignore_mismatched_sizes,
736
+ )
737
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
738
+
739
+ if len(error_msgs) > 0:
740
+ error_msg = "\n\t".join(error_msgs)
741
+ if "size mismatch" in error_msg:
742
+ error_msg += (
743
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
744
+ )
745
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
746
+
747
+ if len(unexpected_keys) > 0:
748
+ logger.warning(
749
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
750
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
751
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
752
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
753
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
754
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
755
+ " identical (initializing a BertForSequenceClassification model from a"
756
+ " BertForSequenceClassification model)."
757
+ )
758
+ else:
759
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
760
+ if len(missing_keys) > 0:
761
+ logger.warning(
762
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
763
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
764
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
765
+ )
766
+ elif len(mismatched_keys) == 0:
767
+ logger.info(
768
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
769
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
770
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
771
+ " without further training."
772
+ )
773
+ if len(mismatched_keys) > 0:
774
+ mismatched_warning = "\n".join(
775
+ [
776
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
777
+ for key, shape1, shape2 in mismatched_keys
778
+ ]
779
+ )
780
+ logger.warning(
781
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
782
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
783
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
784
+ " able to use it for predictions and inference."
785
+ )
786
+
787
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
788
+
789
+ @property
790
+ def device(self) -> device:
791
+ """
792
+ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
793
+ device).
794
+ """
795
+ return get_parameter_device(self)
796
+
797
+ @property
798
+ def dtype(self) -> torch.dtype:
799
+ """
800
+ `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
801
+ """
802
+ return get_parameter_dtype(self)
803
+
804
+ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
805
+ """
806
+ Get number of (optionally, trainable or non-embeddings) parameters in the module.
807
+
808
+ Args:
809
+ only_trainable (`bool`, *optional*, defaults to `False`):
810
+ Whether or not to return only the number of trainable parameters
811
+
812
+ exclude_embeddings (`bool`, *optional*, defaults to `False`):
813
+ Whether or not to return only the number of non-embeddings parameters
814
+
815
+ Returns:
816
+ `int`: The number of parameters.
817
+ """
818
+
819
+ if exclude_embeddings:
820
+ embedding_param_names = [
821
+ f"{name}.weight"
822
+ for name, module_type in self.named_modules()
823
+ if isinstance(module_type, torch.nn.Embedding)
824
+ ]
825
+ non_embedding_parameters = [
826
+ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
827
+ ]
828
+ return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
829
+ else:
830
+ return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
831
+
832
+ def _convert_deprecated_attention_blocks(self, state_dict):
833
+ deprecated_attention_block_paths = []
834
+
835
+ def recursive_find_attn_block(name, module):
836
+ if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
837
+ deprecated_attention_block_paths.append(name)
838
+
839
+ for sub_name, sub_module in module.named_children():
840
+ sub_name = sub_name if name == "" else f"{name}.{sub_name}"
841
+ recursive_find_attn_block(sub_name, sub_module)
842
+
843
+ recursive_find_attn_block("", self)
844
+
845
+ # NOTE: we have to check if the deprecated parameters are in the state dict
846
+ # because it is possible we are loading from a state dict that was already
847
+ # converted
848
+
849
+ for path in deprecated_attention_block_paths:
850
+ # group_norm path stays the same
851
+
852
+ # query -> to_q
853
+ if f"{path}.query.weight" in state_dict:
854
+ state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
855
+ if f"{path}.query.bias" in state_dict:
856
+ state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
857
+
858
+ # key -> to_k
859
+ if f"{path}.key.weight" in state_dict:
860
+ state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
861
+ if f"{path}.key.bias" in state_dict:
862
+ state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
863
+
864
+ # value -> to_v
865
+ if f"{path}.value.weight" in state_dict:
866
+ state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
867
+ if f"{path}.value.bias" in state_dict:
868
+ state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
869
+
870
+ # proj_attn -> to_out.0
871
+ if f"{path}.proj_attn.weight" in state_dict:
872
+ state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
873
+ if f"{path}.proj_attn.bias" in state_dict:
874
+ state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
models/models.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import CLIPTextModel, CLIPTokenizer
3
+ from diffusers import AutoencoderKL, DDIMScheduler, DDIMInverseScheduler, DPMSolverMultistepScheduler
4
+ from .unet_2d_condition import UNet2DConditionModel
5
+ from easydict import EasyDict
6
+ import numpy as np
7
+ # For compatibility
8
+ from utils.latents import get_unscaled_latents, get_scaled_latents, blend_latents
9
+ from utils import torch_device
10
+
11
+ def load_sd(key="runwayml/stable-diffusion-v1-5", use_fp16=False, load_inverse_scheduler=True, use_dpm_multistep_scheduler=False):
12
+ """
13
+ Keys:
14
+ key = "CompVis/stable-diffusion-v1-4"
15
+ key = "runwayml/stable-diffusion-v1-5"
16
+ key = "stabilityai/stable-diffusion-2-1-base"
17
+
18
+ Unpack with:
19
+ ```
20
+ model_dict = load_sd(key=key, use_fp16=use_fp16)
21
+ vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.dtype
22
+ ```
23
+
24
+ use_fp16: fp16 might have degraded performance
25
+ use_dpm_multistep_scheduler: DPMSolverMultistepScheduler
26
+ """
27
+
28
+ # run final results in fp32
29
+ if use_fp16:
30
+ dtype = torch.float16
31
+ revision = "fp16"
32
+ else:
33
+ dtype = torch.float
34
+ revision = "main"
35
+
36
+ vae = AutoencoderKL.from_pretrained(key, subfolder="vae", revision=revision, torch_dtype=dtype).to(torch_device)
37
+ tokenizer = CLIPTokenizer.from_pretrained(key, subfolder="tokenizer", revision=revision, torch_dtype=dtype)
38
+ text_encoder = CLIPTextModel.from_pretrained(key, subfolder="text_encoder", revision=revision, torch_dtype=dtype).to(torch_device)
39
+ unet = UNet2DConditionModel.from_pretrained(key, subfolder="unet", revision=revision, torch_dtype=dtype).to(torch_device)
40
+ if use_dpm_multistep_scheduler:
41
+ scheduler = DPMSolverMultistepScheduler.from_pretrained(key, subfolder="scheduler", revision=revision, torch_dtype=dtype)
42
+ else:
43
+ scheduler = DDIMScheduler.from_pretrained(key, subfolder="scheduler", revision=revision, torch_dtype=dtype)
44
+
45
+ model_dict = EasyDict(vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, dtype=dtype)
46
+
47
+ if load_inverse_scheduler:
48
+ inverse_scheduler = DDIMInverseScheduler.from_config(scheduler.config)
49
+ model_dict.inverse_scheduler = inverse_scheduler
50
+
51
+ return model_dict
52
+
53
+ def encode_prompts(tokenizer, text_encoder, prompts, negative_prompt="", return_full_only=False, one_uncond_input_only=False):
54
+ if negative_prompt == "":
55
+ print("Note that negative_prompt is an empty string")
56
+
57
+ text_input = tokenizer(
58
+ prompts, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
59
+ )
60
+
61
+ max_length = text_input.input_ids.shape[-1]
62
+ if one_uncond_input_only:
63
+ num_uncond_input = 1
64
+ else:
65
+ num_uncond_input = len(prompts)
66
+ uncond_input = tokenizer([negative_prompt] * num_uncond_input, padding="max_length", max_length=max_length, return_tensors="pt")
67
+
68
+ with torch.no_grad():
69
+ uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
70
+ cond_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
71
+
72
+ if one_uncond_input_only:
73
+ return uncond_embeddings, cond_embeddings
74
+
75
+ text_embeddings = torch.cat([uncond_embeddings, cond_embeddings])
76
+
77
+ if return_full_only:
78
+ return text_embeddings
79
+ return text_embeddings, uncond_embeddings, cond_embeddings
80
+
81
+ def attn_list_to_tensor(cross_attention_probs):
82
+ # timestep, CrossAttnBlock, Transformer2DModel, 1xBasicTransformerBlock
83
+
84
+ num_cross_attn_block = len(cross_attention_probs[0])
85
+ cross_attention_probs_all = []
86
+
87
+ for i in range(num_cross_attn_block):
88
+ # cross_attention_probs_timestep[i]: Transformer2DModel
89
+ # 1xBasicTransformerBlock is skipped
90
+ cross_attention_probs_current = []
91
+ for cross_attention_probs_timestep in cross_attention_probs:
92
+ cross_attention_probs_current.append(torch.stack([item for item in cross_attention_probs_timestep[i]], dim=0))
93
+
94
+ cross_attention_probs_current = torch.stack(cross_attention_probs_current, dim=0)
95
+ cross_attention_probs_all.append(cross_attention_probs_current)
96
+
97
+ return cross_attention_probs_all
models/pipelines.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from tqdm import tqdm
3
+ import utils
4
+ from PIL import Image
5
+ import gc
6
+ import numpy as np
7
+ from .attention import GatedSelfAttentionDense
8
+ from .models import torch_device
9
+
10
+ @torch.no_grad()
11
+ def encode(model_dict, image, generator):
12
+ """
13
+ image should be a PIL object or numpy array with range 0 to 255
14
+ """
15
+
16
+ vae, dtype = model_dict.vae, model_dict.dtype
17
+
18
+ if isinstance(image, Image.Image):
19
+ w, h = image.size
20
+ assert w % 8 == 0 and h % 8 == 0, f"h ({h}) and w ({w}) should be a multiple of 8"
21
+ # w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
22
+ # image = np.array(image.resize((w, h), resample=Image.Resampling.LANCZOS))[None, :]
23
+ image = np.array(image)
24
+
25
+ if isinstance(image, np.ndarray):
26
+ assert image.dtype == np.uint8, f"Should have dtype uint8 (dtype: {image.dtype})"
27
+ image = image.astype(np.float32) / 255.0
28
+ image = image[None, ...]
29
+ image = image.transpose(0, 3, 1, 2)
30
+ image = 2.0 * image - 1.0
31
+ image = torch.from_numpy(image)
32
+
33
+ assert isinstance(image, torch.Tensor), f"type of image: {type(image)}"
34
+
35
+ image = image.to(device=torch_device, dtype=dtype)
36
+ latents = vae.encode(image).latent_dist.sample(generator)
37
+
38
+ latents = vae.config.scaling_factor * latents
39
+
40
+ return latents
41
+
42
+ @torch.no_grad()
43
+ def decode(vae, latents):
44
+ # scale and decode the image latents with vae
45
+ scaled_latents = 1 / 0.18215 * latents
46
+ with torch.no_grad():
47
+ image = vae.decode(scaled_latents).sample
48
+
49
+ image = (image / 2 + 0.5).clamp(0, 1)
50
+ image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
51
+ images = (image * 255).round().astype("uint8")
52
+
53
+ return images
54
+
55
+ @torch.no_grad()
56
+ def generate(model_dict, latents, input_embeddings, num_inference_steps, guidance_scale = 7.5, no_set_timesteps=False):
57
+ vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.dtype
58
+ text_embeddings, uncond_embeddings, cond_embeddings = input_embeddings
59
+
60
+ if not no_set_timesteps:
61
+ scheduler.set_timesteps(num_inference_steps)
62
+
63
+ for t in tqdm(scheduler.timesteps):
64
+ # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
65
+ latent_model_input = torch.cat([latents] * 2)
66
+
67
+ latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
68
+
69
+ # predict the noise residual
70
+ with torch.no_grad():
71
+ noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
72
+
73
+ # perform guidance
74
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
75
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
76
+
77
+ # compute the previous noisy sample x_t -> x_t-1
78
+ latents = scheduler.step(noise_pred, t, latents).prev_sample
79
+
80
+ images = decode(vae, latents)
81
+
82
+ ret = [latents, images]
83
+
84
+ return tuple(ret)
85
+
86
+ def gligen_enable_fuser(unet, enabled=True):
87
+ for module in unet.modules():
88
+ if isinstance(module, GatedSelfAttentionDense):
89
+ module.enabled = enabled
90
+
91
+ @torch.no_grad()
92
+ def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps, bboxes, phrases, num_images_per_prompt=1, gligen_scheduled_sampling_beta: float = 0.3, guidance_scale=7.5,
93
+ frozen_steps=20, frozen_mask=None,
94
+ return_saved_cross_attn=False, saved_cross_attn_keys=None, return_cond_ca_only=False, return_token_ca_only=None,
95
+ offload_cross_attn_to_cpu=False, offload_latents_to_cpu=True,
96
+ semantic_guidance=False, semantic_guidance_bboxes=None, semantic_guidance_object_positions=None, semantic_guidance_kwargs=None,
97
+ return_box_vis=False, show_progress=True, save_all_latents=False):
98
+ """
99
+ The `bboxes` should be a list, rather than a list of lists (one box per phrase, we can have multiple duplicated phrases).
100
+ """
101
+ vae, tokenizer, text_encoder, unet, scheduler, dtype = model_dict.vae, model_dict.tokenizer, model_dict.text_encoder, model_dict.unet, model_dict.scheduler, model_dict.dtype
102
+ text_embeddings, uncond_embeddings, cond_embeddings = input_embeddings
103
+
104
+ if latents.dim() == 5:
105
+ # latents_all from the input side, different from the latents_all to be saved
106
+ latents_all_input = latents
107
+ latents = latents[0]
108
+ else:
109
+ latents_all_input = None
110
+
111
+ # Just in case that we have in-place ops
112
+ latents = latents.clone()
113
+
114
+ if save_all_latents:
115
+ # offload to cpu to save space
116
+ if offload_latents_to_cpu:
117
+ latents_all = [latents.cpu()]
118
+ else:
119
+ latents_all = [latents]
120
+
121
+ scheduler.set_timesteps(num_inference_steps)
122
+
123
+ if frozen_mask is not None:
124
+ frozen_mask = frozen_mask.to(dtype=dtype).clamp(0., 1.)
125
+
126
+ batch_size = 1
127
+
128
+ # 5.1 Prepare GLIGEN variables
129
+ assert len(phrases) == len(bboxes)
130
+ # assert batch_size == 1
131
+ max_objs = 30
132
+ _boxes = bboxes
133
+
134
+ n_objs = min(len(_boxes), max_objs)
135
+ boxes = torch.zeros(max_objs, 4, device=torch_device, dtype=dtype)
136
+ phrase_embeddings = torch.zeros(max_objs, 768, device=torch_device, dtype=dtype)
137
+ masks = torch.zeros(max_objs, device=torch_device, dtype=dtype)
138
+
139
+ if n_objs > 0:
140
+ boxes[:n_objs] = torch.tensor(_boxes[:n_objs])
141
+ tokenizer_inputs = tokenizer(phrases, padding=True, return_tensors="pt").to(torch_device)
142
+ _phrase_embeddings = text_encoder(**tokenizer_inputs).pooler_output
143
+ phrase_embeddings[:n_objs] = _phrase_embeddings[:n_objs]
144
+ masks[:n_objs] = 1
145
+
146
+ # Classifier-free guidance
147
+ repeat_batch = batch_size * num_images_per_prompt * 2
148
+
149
+ boxes = boxes.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
150
+ phrase_embeddings = phrase_embeddings.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
151
+ masks = masks.unsqueeze(0).expand(repeat_batch, -1).clone()
152
+ masks[:repeat_batch // 2] = 0
153
+
154
+ if semantic_guidance_bboxes and semantic_guidance:
155
+ loss = torch.tensor(10000.)
156
+ # TODO: we can also save necessary tokens only to save memory.
157
+ # offload_guidance_cross_attn_to_cpu does not save too much since we only store attention map for each timestep.
158
+ guidance_cross_attention_kwargs = {
159
+ 'offload_cross_attn_to_cpu': False,
160
+ 'enable_flash_attn': False,
161
+ 'gligen': {
162
+ 'boxes': boxes[:repeat_batch // 2],
163
+ 'positive_embeddings': phrase_embeddings[:repeat_batch // 2],
164
+ 'masks': masks[:repeat_batch // 2],
165
+ 'fuser_attn_kwargs': {
166
+ 'enable_flash_attn': False,
167
+ }
168
+ }
169
+ }
170
+
171
+ if return_saved_cross_attn:
172
+ saved_attns = []
173
+
174
+ main_cross_attention_kwargs = {
175
+ 'offload_cross_attn_to_cpu': offload_cross_attn_to_cpu,
176
+ 'return_cond_ca_only': return_cond_ca_only,
177
+ 'return_token_ca_only': return_token_ca_only,
178
+ 'save_keys': saved_cross_attn_keys,
179
+ 'gligen': {
180
+ 'boxes': boxes,
181
+ 'positive_embeddings': phrase_embeddings,
182
+ 'masks': masks
183
+ }
184
+ }
185
+
186
+ timesteps = scheduler.timesteps
187
+
188
+ num_grounding_steps = int(gligen_scheduled_sampling_beta * len(timesteps))
189
+ gligen_enable_fuser(unet, True)
190
+
191
+ for index, t in enumerate(tqdm(timesteps, disable=not show_progress)):
192
+ # Scheduled sampling
193
+ if index == num_grounding_steps:
194
+ gligen_enable_fuser(unet, False)
195
+
196
+ if semantic_guidance_bboxes and semantic_guidance:
197
+ with torch.enable_grad():
198
+ latents, loss = latent_backward_guidance(scheduler, unet, cond_embeddings, index, semantic_guidance_bboxes, semantic_guidance_object_positions, t, latents, loss, cross_attention_kwargs=guidance_cross_attention_kwargs, **semantic_guidance_kwargs)
199
+ # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
200
+ latent_model_input = torch.cat([latents] * 2)
201
+
202
+ latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
203
+
204
+ main_cross_attention_kwargs['save_attn_to_dict'] = {}
205
+
206
+ # predict the noise residual
207
+ noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings,
208
+ cross_attention_kwargs=main_cross_attention_kwargs).sample
209
+
210
+ if return_saved_cross_attn:
211
+ saved_attns.append(main_cross_attention_kwargs['save_attn_to_dict'])
212
+
213
+ del main_cross_attention_kwargs['save_attn_to_dict']
214
+
215
+ # perform guidance
216
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
217
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
218
+
219
+ # compute the previous noisy sample x_t -> x_t-1
220
+ latents = scheduler.step(noise_pred, t, latents).prev_sample
221
+
222
+ if frozen_mask is not None and index < frozen_steps:
223
+ latents = latents_all_input[index+1] * frozen_mask + latents * (1. - frozen_mask)
224
+
225
+ if save_all_latents:
226
+ if offload_latents_to_cpu:
227
+ latents_all.append(latents.cpu())
228
+ else:
229
+ latents_all.append(latents)
230
+
231
+ # Turn off fuser for typical SD
232
+ gligen_enable_fuser(unet, False)
233
+ images = decode(vae, latents)
234
+
235
+ ret = [latents, images]
236
+ if return_saved_cross_attn:
237
+ ret.append(saved_attns)
238
+ if return_box_vis:
239
+ pil_images = [utils.draw_box(Image.fromarray(image), bboxes, phrases) for image in images]
240
+ ret.append(pil_images)
241
+ if save_all_latents:
242
+ latents_all = torch.stack(latents_all, dim=0)
243
+ ret.append(latents_all)
244
+
245
+ return tuple(ret)
246
+
models/sam.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ import torch
5
+ from models import torch_device
6
+ from transformers import SamModel, SamProcessor
7
+ import utils
8
+ import cv2
9
+ from scipy import ndimage
10
+
11
+ def load_sam():
12
+ sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to(torch_device)
13
+ sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
14
+
15
+ sam_model_dict = dict(
16
+ sam_model = sam_model, sam_processor = sam_processor
17
+ )
18
+
19
+ return sam_model_dict
20
+
21
+ # Not fully backward compatible with the previous implementation
22
+ # Reference: lmdv2/notebooks/gen_masked_latents_multi_object_ref_ca_loss_modular.ipynb
23
+ def sam(sam_model_dict, image, input_points=None, input_boxes=None, target_mask_shape=None):
24
+ """target_mask_shape: (h, w)"""
25
+ sam_model, sam_processor = sam_model_dict['sam_model'], sam_model_dict['sam_processor']
26
+
27
+ with torch.no_grad():
28
+ with torch.autocast(torch_device):
29
+ inputs = sam_processor(image, input_points=input_points, input_boxes=input_boxes, return_tensors="pt").to(torch_device)
30
+ outputs = sam_model(**inputs)
31
+ masks = sam_processor.image_processor.post_process_masks(
32
+ outputs.pred_masks.cpu().float(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
33
+ )
34
+ conf_scores = outputs.iou_scores.to(device="cpu", dtype=torch.float32).numpy()[0,0]
35
+ del inputs, outputs
36
+
37
+ gc.collect()
38
+ if torch_device == "cuda":
39
+ torch.cuda.empty_cache()
40
+
41
+ masks = masks[0][0].numpy()
42
+
43
+ if target_mask_shape is not None:
44
+ masks = np.array([cv2.resize(mask.astype(np.uint8) * 255, target_mask_shape[::-1], cv2.INTER_LINEAR).astype(bool) for mask in masks])
45
+
46
+ return masks, conf_scores
47
+
48
+ def sam_point_input(sam_model_dict, image, input_points, **kwargs):
49
+ return sam(sam_model_dict, image, input_points=input_points, **kwargs)
50
+
51
+ def sam_box_input(sam_model_dict, image, input_boxes, **kwargs):
52
+ return sam(sam_model_dict, image, input_boxes=input_boxes, **kwargs)
53
+
54
+ def get_iou_with_resize(mask, masks, masks_shape):
55
+ masks = np.array([cv2.resize(mask.astype(np.uint8) * 255, masks_shape[::-1], cv2.INTER_LINEAR).astype(bool) for mask in masks])
56
+ return utils.iou(mask, masks)
57
+
58
+ def select_mask(masks, conf_scores, coarse_ious=None, rule="largest_over_conf", discourage_mask_below_confidence=0.85, discourage_mask_below_coarse_iou=0.2, verbose=False):
59
+ """masks: numpy bool array"""
60
+ mask_sizes = masks.sum(axis=(1, 2))
61
+
62
+ # Another possible rule: iou with the attention mask
63
+ if rule == "largest_over_conf":
64
+ # Use the largest segmentation
65
+ # Discourage selecting masks with conf too low or coarse iou is too low
66
+ max_mask_size = np.max(mask_sizes)
67
+ if coarse_ious is not None:
68
+ scores = mask_sizes - (conf_scores < discourage_mask_below_confidence) * max_mask_size - (coarse_ious < discourage_mask_below_coarse_iou) * max_mask_size
69
+ else:
70
+ scores = mask_sizes - (conf_scores < discourage_mask_below_confidence) * max_mask_size
71
+ if verbose:
72
+ print(f"mask_sizes: {mask_sizes}, scores: {scores}")
73
+ else:
74
+ raise ValueError(f"Unknown rule: {rule}")
75
+
76
+ mask_id = np.argmax(scores)
77
+ mask = masks[mask_id]
78
+
79
+ selection_conf = conf_scores[mask_id]
80
+
81
+ if coarse_ious is not None:
82
+ selection_coarse_iou = coarse_ious[mask_id]
83
+ else:
84
+ selection_coarse_iou = None
85
+
86
+ if verbose:
87
+ # print(f"Confidences: {conf_scores}")
88
+ print(f"Selected a mask with confidence: {selection_conf}, coarse_iou: {selection_coarse_iou}")
89
+
90
+ if verbose:
91
+ plt.figure(figsize=(10, 8))
92
+ # plt.suptitle("After SAM")
93
+ for ind in range(3):
94
+ plt.subplot(1, 3, ind+1)
95
+ # This is obtained before resize.
96
+ plt.title(f"Mask {ind}, score {scores[ind]}, conf {conf_scores[ind]:.2f}, iou {coarse_ious[ind] if coarse_ious is not None else None:.2f}")
97
+ plt.imshow(masks[ind])
98
+ plt.tight_layout()
99
+ plt.show()
100
+
101
+ return mask, selection_conf
102
+
103
+ def preprocess_mask(token_attn_np_smooth, mask_th, n_erode_dilate_mask=0):
104
+ token_attn_np_smooth_normalized = token_attn_np_smooth - token_attn_np_smooth.min()
105
+ token_attn_np_smooth_normalized /= token_attn_np_smooth_normalized.max()
106
+ mask_thresholded = token_attn_np_smooth_normalized > mask_th
107
+
108
+ if n_erode_dilate_mask:
109
+ mask_thresholded = ndimage.binary_erosion(mask_thresholded, iterations=n_erode_dilate_mask)
110
+ mask_thresholded = ndimage.binary_dilation(mask_thresholded, iterations=n_erode_dilate_mask)
111
+
112
+ return mask_thresholded
113
+
114
+ # The overall pipeline to refine the attention mask
115
+ def sam_refine_attn(sam_input_image, token_attn_np, model_dict, height, width, H, W, use_box_input, gaussian_sigma, mask_th_for_box, n_erode_dilate_mask_for_box, mask_th_for_point, discourage_mask_below_confidence, discourage_mask_below_coarse_iou, verbose):
116
+
117
+ # token_attn_np is for visualizations
118
+ token_attn_np_smooth = ndimage.gaussian_filter(token_attn_np, sigma=gaussian_sigma)
119
+
120
+ # (w, h)
121
+ mask_size_scale = height // token_attn_np_smooth.shape[1], width // token_attn_np_smooth.shape[0]
122
+
123
+ if use_box_input:
124
+ # box input
125
+ mask_binary = preprocess_mask(token_attn_np_smooth, mask_th_for_box, n_erode_dilate_mask=n_erode_dilate_mask_for_box)
126
+
127
+ input_boxes = utils.binary_mask_to_box(mask_binary, w_scale=mask_size_scale[0], h_scale=mask_size_scale[1])
128
+ input_boxes = [input_boxes]
129
+
130
+ masks, conf_scores = sam_box_input(model_dict, image=sam_input_image, input_boxes=input_boxes, target_mask_shape=(H, W))
131
+ else:
132
+ # point input
133
+ mask_binary = preprocess_mask(token_attn_np_smooth, mask_th_for_point, n_erode_dilate_mask=0)
134
+
135
+ # Uses the max coordinate only
136
+ max_coord = np.unravel_index(token_attn_np_smooth.argmax(), token_attn_np_smooth.shape)
137
+ # print("max_coord:", max_coord)
138
+ input_points = [[[max_coord[1] * mask_size_scale[1], max_coord[0] * mask_size_scale[0]]]]
139
+
140
+ masks, conf_scores = sam_point_input(model_dict, image=sam_input_image, input_points=input_points, target_mask_shape=(H, W))
141
+
142
+ if verbose:
143
+ plt.title("Coarse binary mask (for box for box input and for iou)")
144
+ plt.imshow(mask_binary)
145
+ plt.show()
146
+
147
+ coarse_ious = get_iou_with_resize(mask_binary, masks, masks_shape=mask_binary.shape)
148
+
149
+ mask_selected, conf_score_selected = select_mask(masks, conf_scores, coarse_ious=coarse_ious,
150
+ rule="largest_over_conf",
151
+ discourage_mask_below_confidence=discourage_mask_below_confidence,
152
+ discourage_mask_below_coarse_iou=discourage_mask_below_coarse_iou,
153
+ verbose=True)
154
+
155
+ return mask_selected, conf_score_selected
156
+
157
+ def sam_refine_box(sam_input_image, box, model_dict, height, width, H, W, discourage_mask_below_confidence, discourage_mask_below_coarse_iou, verbose):
158
+ # (w, h)
159
+ input_boxes = utils.scale_proportion(box, H=height, W=width)
160
+ input_boxes = [input_boxes]
161
+
162
+ masks, conf_scores = sam_box_input(model_dict, image=sam_input_image, input_boxes=input_boxes, target_mask_shape=(H, W))
163
+
164
+ mask_binary = utils.proportion_to_mask(box, H, W, return_np=True)
165
+ if verbose:
166
+ # Also the box is the input for SAM
167
+ plt.title("Binary mask from input box (for iou)")
168
+ plt.imshow(mask_binary)
169
+ plt.show()
170
+
171
+ coarse_ious = get_iou_with_resize(mask_binary, masks, masks_shape=mask_binary.shape)
172
+
173
+ mask_selected, conf_score_selected = select_mask(masks, conf_scores, coarse_ious=coarse_ious,
174
+ rule="largest_over_conf",
175
+ discourage_mask_below_confidence=discourage_mask_below_confidence,
176
+ discourage_mask_below_coarse_iou=discourage_mask_below_coarse_iou,
177
+ verbose=True)
178
+
179
+ return mask_selected, conf_score_selected
models/transformer_2d.py ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
23
+ from diffusers.utils import BaseOutput, deprecate
24
+ from .attention import BasicTransformerBlock
25
+ from diffusers.models.embeddings import PatchEmbed
26
+ from diffusers.models.modeling_utils import ModelMixin
27
+
28
+
29
+ @dataclass
30
+ class Transformer2DModelOutput(BaseOutput):
31
+ """
32
+ Args:
33
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
34
+ Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
35
+ for the unnoised latent pixels.
36
+ """
37
+
38
+ sample: torch.FloatTensor
39
+
40
+
41
+ class Transformer2DModel(ModelMixin, ConfigMixin):
42
+ """
43
+ Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
44
+ embeddings) inputs.
45
+
46
+ When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
47
+ transformer action. Finally, reshape to image.
48
+
49
+ When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
50
+ embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
51
+ classes of unnoised image.
52
+
53
+ Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
54
+ image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
55
+
56
+ Parameters:
57
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
58
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
59
+ in_channels (`int`, *optional*):
60
+ Pass if the input is continuous. The number of channels in the input and output.
61
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
62
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
63
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
64
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
65
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
66
+ `ImagePositionalEmbeddings`.
67
+ num_vector_embeds (`int`, *optional*):
68
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
69
+ Includes the class for the masked latent pixel.
70
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
71
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
72
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
73
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
74
+ up to but not more than steps than `num_embeds_ada_norm`.
75
+ attention_bias (`bool`, *optional*):
76
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
77
+ """
78
+
79
+ @register_to_config
80
+ def __init__(
81
+ self,
82
+ num_attention_heads: int = 16,
83
+ attention_head_dim: int = 88,
84
+ in_channels: Optional[int] = None,
85
+ out_channels: Optional[int] = None,
86
+ num_layers: int = 1,
87
+ dropout: float = 0.0,
88
+ norm_num_groups: int = 32,
89
+ cross_attention_dim: Optional[int] = None,
90
+ attention_bias: bool = False,
91
+ sample_size: Optional[int] = None,
92
+ num_vector_embeds: Optional[int] = None,
93
+ patch_size: Optional[int] = None,
94
+ activation_fn: str = "geglu",
95
+ num_embeds_ada_norm: Optional[int] = None,
96
+ use_linear_projection: bool = False,
97
+ only_cross_attention: bool = False,
98
+ upcast_attention: bool = False,
99
+ norm_type: str = "layer_norm",
100
+ norm_elementwise_affine: bool = True,
101
+ use_gated_attention: bool = False,
102
+ ):
103
+ super().__init__()
104
+ self.use_linear_projection = use_linear_projection
105
+ self.num_attention_heads = num_attention_heads
106
+ self.attention_head_dim = attention_head_dim
107
+ inner_dim = num_attention_heads * attention_head_dim
108
+
109
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
110
+ # Define whether input is continuous or discrete depending on configuration
111
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
112
+ self.is_input_vectorized = num_vector_embeds is not None
113
+ self.is_input_patches = in_channels is not None and patch_size is not None
114
+
115
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
116
+ deprecation_message = (
117
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
118
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
119
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
120
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
121
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
122
+ )
123
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
124
+ norm_type = "ada_norm"
125
+
126
+ if self.is_input_continuous and self.is_input_vectorized:
127
+ raise ValueError(
128
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
129
+ " sure that either `in_channels` or `num_vector_embeds` is None."
130
+ )
131
+ elif self.is_input_vectorized and self.is_input_patches:
132
+ raise ValueError(
133
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
134
+ " sure that either `num_vector_embeds` or `num_patches` is None."
135
+ )
136
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
137
+ raise ValueError(
138
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
139
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
140
+ )
141
+
142
+ # 2. Define input layers
143
+ if self.is_input_continuous:
144
+ self.in_channels = in_channels
145
+
146
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
147
+ if use_linear_projection:
148
+ self.proj_in = nn.Linear(in_channels, inner_dim)
149
+ else:
150
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
151
+ elif self.is_input_vectorized:
152
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
153
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
154
+
155
+ self.height = sample_size
156
+ self.width = sample_size
157
+ self.num_vector_embeds = num_vector_embeds
158
+ self.num_latent_pixels = self.height * self.width
159
+
160
+ self.latent_image_embedding = ImagePositionalEmbeddings(
161
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
162
+ )
163
+ elif self.is_input_patches:
164
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
165
+
166
+ self.height = sample_size
167
+ self.width = sample_size
168
+
169
+ self.patch_size = patch_size
170
+ self.pos_embed = PatchEmbed(
171
+ height=sample_size,
172
+ width=sample_size,
173
+ patch_size=patch_size,
174
+ in_channels=in_channels,
175
+ embed_dim=inner_dim,
176
+ )
177
+
178
+ # 3. Define transformers blocks
179
+ self.transformer_blocks = nn.ModuleList(
180
+ [
181
+ BasicTransformerBlock(
182
+ inner_dim,
183
+ num_attention_heads,
184
+ attention_head_dim,
185
+ dropout=dropout,
186
+ cross_attention_dim=cross_attention_dim,
187
+ activation_fn=activation_fn,
188
+ num_embeds_ada_norm=num_embeds_ada_norm,
189
+ attention_bias=attention_bias,
190
+ only_cross_attention=only_cross_attention,
191
+ upcast_attention=upcast_attention,
192
+ norm_type=norm_type,
193
+ norm_elementwise_affine=norm_elementwise_affine,
194
+ use_gated_attention=use_gated_attention,
195
+ )
196
+ for d in range(num_layers)
197
+ ]
198
+ )
199
+
200
+ # 4. Define output layers
201
+ self.out_channels = in_channels if out_channels is None else out_channels
202
+ if self.is_input_continuous:
203
+ # TODO: should use out_channels for continuous projections
204
+ if use_linear_projection:
205
+ self.proj_out = nn.Linear(inner_dim, in_channels)
206
+ else:
207
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
208
+ elif self.is_input_vectorized:
209
+ self.norm_out = nn.LayerNorm(inner_dim)
210
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
211
+ elif self.is_input_patches:
212
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
213
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
214
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
215
+
216
+ def forward(
217
+ self,
218
+ hidden_states: torch.Tensor,
219
+ encoder_hidden_states: Optional[torch.Tensor] = None,
220
+ timestep: Optional[torch.LongTensor] = None,
221
+ class_labels: Optional[torch.LongTensor] = None,
222
+ cross_attention_kwargs: Dict[str, Any] = None,
223
+ attention_mask: Optional[torch.Tensor] = None,
224
+ encoder_attention_mask: Optional[torch.Tensor] = None,
225
+ return_dict: bool = True,
226
+ return_cross_attention_probs: bool = False,
227
+ ):
228
+ """
229
+ Args:
230
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
231
+ When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
232
+ hidden_states
233
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
234
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
235
+ self-attention.
236
+ timestep ( `torch.LongTensor`, *optional*):
237
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
238
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
239
+ Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
240
+ conditioning.
241
+ encoder_attention_mask ( `torch.Tensor`, *optional* ).
242
+ Cross-attention mask, applied to encoder_hidden_states. Two formats supported:
243
+ Mask `(batch, sequence_length)` True = keep, False = discard. Bias `(batch, 1, sequence_length)` 0
244
+ = keep, -10000 = discard.
245
+ If ndim == 2: will be interpreted as a mask, then converted into a bias consistent with the format
246
+ above. This bias will be added to the cross-attention scores.
247
+ return_dict (`bool`, *optional*, defaults to `True`):
248
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
249
+
250
+ Returns:
251
+ [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
252
+ [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
253
+ returning a tuple, the first element is the sample tensor.
254
+ """
255
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
256
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
257
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
258
+ # expects mask of shape:
259
+ # [batch, key_tokens]
260
+ # adds singleton query_tokens dimension:
261
+ # [batch, 1, key_tokens]
262
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
263
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
264
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
265
+ if attention_mask is not None and attention_mask.ndim == 2:
266
+ # assume that mask is expressed as:
267
+ # (1 = keep, 0 = discard)
268
+ # convert mask into a bias that can be added to attention scores:
269
+ # (keep = +0, discard = -10000.0)
270
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
271
+ attention_mask = attention_mask.unsqueeze(1)
272
+
273
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
274
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
275
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
276
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
277
+
278
+ # 1. Input
279
+ if self.is_input_continuous:
280
+ batch, _, height, width = hidden_states.shape
281
+ residual = hidden_states
282
+
283
+ hidden_states = self.norm(hidden_states)
284
+ if not self.use_linear_projection:
285
+ hidden_states = self.proj_in(hidden_states)
286
+ inner_dim = hidden_states.shape[1]
287
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
288
+ else:
289
+ inner_dim = hidden_states.shape[1]
290
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
291
+ hidden_states = self.proj_in(hidden_states)
292
+ elif self.is_input_vectorized:
293
+ hidden_states = self.latent_image_embedding(hidden_states)
294
+ elif self.is_input_patches:
295
+ hidden_states = self.pos_embed(hidden_states)
296
+
297
+ base_attn_key = cross_attention_kwargs["attn_key"]
298
+
299
+ # 2. Blocks
300
+ cross_attention_probs_all = []
301
+ for block_ind, block in enumerate(self.transformer_blocks):
302
+ cross_attention_kwargs["attn_key"] = base_attn_key + [block_ind]
303
+
304
+ hidden_states = block(
305
+ hidden_states,
306
+ attention_mask=attention_mask,
307
+ encoder_hidden_states=encoder_hidden_states,
308
+ encoder_attention_mask=encoder_attention_mask,
309
+ timestep=timestep,
310
+ cross_attention_kwargs=cross_attention_kwargs,
311
+ class_labels=class_labels,
312
+ return_cross_attention_probs=return_cross_attention_probs,
313
+ )
314
+ if return_cross_attention_probs:
315
+ hidden_states, cross_attention_probs = hidden_states
316
+ cross_attention_probs_all.append(cross_attention_probs)
317
+
318
+ # 3. Output
319
+ if self.is_input_continuous:
320
+ if not self.use_linear_projection:
321
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
322
+ hidden_states = self.proj_out(hidden_states)
323
+ else:
324
+ hidden_states = self.proj_out(hidden_states)
325
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
326
+
327
+ output = hidden_states + residual
328
+ elif self.is_input_vectorized:
329
+ hidden_states = self.norm_out(hidden_states)
330
+ logits = self.out(hidden_states)
331
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
332
+ logits = logits.permute(0, 2, 1)
333
+
334
+ # log(p(x_0))
335
+ output = F.log_softmax(logits.double(), dim=1).float()
336
+ elif self.is_input_patches:
337
+ # TODO: cleanup!
338
+ conditioning = self.transformer_blocks[0].norm1.emb(
339
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
340
+ )
341
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
342
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
343
+ hidden_states = self.proj_out_2(hidden_states)
344
+
345
+ # unpatchify
346
+ height = width = int(hidden_states.shape[1] ** 0.5)
347
+ hidden_states = hidden_states.reshape(
348
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
349
+ )
350
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
351
+ output = hidden_states.reshape(
352
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
353
+ )
354
+
355
+ if len(cross_attention_probs_all) == 1:
356
+ # If we only have one transformer block in a Transformer2DModel, we do not create another nested level.
357
+ cross_attention_probs_all = cross_attention_probs_all[0]
358
+
359
+ if not return_dict:
360
+ if return_cross_attention_probs:
361
+ return (output, cross_attention_probs_all)
362
+ return (output,)
363
+
364
+ output = Transformer2DModelOutput(sample=output)
365
+ if return_cross_attention_probs:
366
+ return output, cross_attention_probs_all
367
+ return output
models/unet_2d_blocks.py ADDED
@@ -0,0 +1,793 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional, Tuple
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.utils import is_torch_version
22
+ from diffusers.models.dual_transformer_2d import DualTransformer2DModel
23
+ from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
24
+ from .transformer_2d import Transformer2DModel
25
+
26
+
27
+ def get_down_block(
28
+ down_block_type,
29
+ num_layers,
30
+ in_channels,
31
+ out_channels,
32
+ temb_channels,
33
+ add_downsample,
34
+ resnet_eps,
35
+ resnet_act_fn,
36
+ attn_num_head_channels,
37
+ resnet_groups=None,
38
+ cross_attention_dim=None,
39
+ downsample_padding=None,
40
+ dual_cross_attention=False,
41
+ use_linear_projection=False,
42
+ only_cross_attention=False,
43
+ upcast_attention=False,
44
+ resnet_time_scale_shift="default",
45
+ resnet_skip_time_act=False,
46
+ resnet_out_scale_factor=1.0,
47
+ cross_attention_norm=None,
48
+ use_gated_attention=False,
49
+ ):
50
+ down_block_type = down_block_type[7:] if down_block_type.startswith(
51
+ "UNetRes") else down_block_type
52
+ if down_block_type == "DownBlock2D":
53
+ return DownBlock2D(
54
+ num_layers=num_layers,
55
+ in_channels=in_channels,
56
+ out_channels=out_channels,
57
+ temb_channels=temb_channels,
58
+ add_downsample=add_downsample,
59
+ resnet_eps=resnet_eps,
60
+ resnet_act_fn=resnet_act_fn,
61
+ resnet_groups=resnet_groups,
62
+ downsample_padding=downsample_padding,
63
+ resnet_time_scale_shift=resnet_time_scale_shift,
64
+ )
65
+ elif down_block_type == "CrossAttnDownBlock2D":
66
+ if cross_attention_dim is None:
67
+ raise ValueError(
68
+ "cross_attention_dim must be specified for CrossAttnDownBlock2D")
69
+ return CrossAttnDownBlock2D(
70
+ num_layers=num_layers,
71
+ in_channels=in_channels,
72
+ out_channels=out_channels,
73
+ temb_channels=temb_channels,
74
+ add_downsample=add_downsample,
75
+ resnet_eps=resnet_eps,
76
+ resnet_act_fn=resnet_act_fn,
77
+ resnet_groups=resnet_groups,
78
+ downsample_padding=downsample_padding,
79
+ cross_attention_dim=cross_attention_dim,
80
+ attn_num_head_channels=attn_num_head_channels,
81
+ dual_cross_attention=dual_cross_attention,
82
+ use_linear_projection=use_linear_projection,
83
+ only_cross_attention=only_cross_attention,
84
+ upcast_attention=upcast_attention,
85
+ resnet_time_scale_shift=resnet_time_scale_shift,
86
+ use_gated_attention=use_gated_attention,
87
+ )
88
+
89
+ raise ValueError(f"{down_block_type} does not exist.")
90
+
91
+
92
+ def get_up_block(
93
+ up_block_type,
94
+ num_layers,
95
+ in_channels,
96
+ out_channels,
97
+ prev_output_channel,
98
+ temb_channels,
99
+ add_upsample,
100
+ resnet_eps,
101
+ resnet_act_fn,
102
+ attn_num_head_channels,
103
+ resnet_groups=None,
104
+ cross_attention_dim=None,
105
+ dual_cross_attention=False,
106
+ use_linear_projection=False,
107
+ only_cross_attention=False,
108
+ upcast_attention=False,
109
+ resnet_time_scale_shift="default",
110
+ resnet_skip_time_act=False,
111
+ resnet_out_scale_factor=1.0,
112
+ cross_attention_norm=None,
113
+ use_gated_attention=False,
114
+ ):
115
+ up_block_type = up_block_type[7:] if up_block_type.startswith(
116
+ "UNetRes") else up_block_type
117
+ if up_block_type == "UpBlock2D":
118
+ return UpBlock2D(
119
+ num_layers=num_layers,
120
+ in_channels=in_channels,
121
+ out_channels=out_channels,
122
+ prev_output_channel=prev_output_channel,
123
+ temb_channels=temb_channels,
124
+ add_upsample=add_upsample,
125
+ resnet_eps=resnet_eps,
126
+ resnet_act_fn=resnet_act_fn,
127
+ resnet_groups=resnet_groups,
128
+ resnet_time_scale_shift=resnet_time_scale_shift,
129
+ )
130
+ elif up_block_type == "CrossAttnUpBlock2D":
131
+ if cross_attention_dim is None:
132
+ raise ValueError(
133
+ "cross_attention_dim must be specified for CrossAttnUpBlock2D")
134
+ return CrossAttnUpBlock2D(
135
+ num_layers=num_layers,
136
+ in_channels=in_channels,
137
+ out_channels=out_channels,
138
+ prev_output_channel=prev_output_channel,
139
+ temb_channels=temb_channels,
140
+ add_upsample=add_upsample,
141
+ resnet_eps=resnet_eps,
142
+ resnet_act_fn=resnet_act_fn,
143
+ resnet_groups=resnet_groups,
144
+ cross_attention_dim=cross_attention_dim,
145
+ attn_num_head_channels=attn_num_head_channels,
146
+ dual_cross_attention=dual_cross_attention,
147
+ use_linear_projection=use_linear_projection,
148
+ only_cross_attention=only_cross_attention,
149
+ upcast_attention=upcast_attention,
150
+ resnet_time_scale_shift=resnet_time_scale_shift,
151
+ use_gated_attention=use_gated_attention,
152
+ )
153
+
154
+ raise ValueError(f"{up_block_type} does not exist.")
155
+
156
+
157
+ class UNetMidBlock2DCrossAttn(nn.Module):
158
+ def __init__(
159
+ self,
160
+ in_channels: int,
161
+ temb_channels: int,
162
+ dropout: float = 0.0,
163
+ num_layers: int = 1,
164
+ resnet_eps: float = 1e-6,
165
+ resnet_time_scale_shift: str = "default",
166
+ resnet_act_fn: str = "swish",
167
+ resnet_groups: int = 32,
168
+ resnet_pre_norm: bool = True,
169
+ attn_num_head_channels=1,
170
+ output_scale_factor=1.0,
171
+ cross_attention_dim=1280,
172
+ dual_cross_attention=False,
173
+ use_linear_projection=False,
174
+ upcast_attention=False,
175
+ use_gated_attention=False,
176
+ ):
177
+ super().__init__()
178
+
179
+ self.has_cross_attention = True
180
+ self.attn_num_head_channels = attn_num_head_channels
181
+ resnet_groups = resnet_groups if resnet_groups is not None else min(
182
+ in_channels // 4, 32)
183
+
184
+ # there is always at least one resnet
185
+ resnets = [
186
+ ResnetBlock2D(
187
+ in_channels=in_channels,
188
+ out_channels=in_channels,
189
+ temb_channels=temb_channels,
190
+ eps=resnet_eps,
191
+ groups=resnet_groups,
192
+ dropout=dropout,
193
+ time_embedding_norm=resnet_time_scale_shift,
194
+ non_linearity=resnet_act_fn,
195
+ output_scale_factor=output_scale_factor,
196
+ pre_norm=resnet_pre_norm,
197
+ )
198
+ ]
199
+ attentions = []
200
+
201
+ for _ in range(num_layers):
202
+ if not dual_cross_attention:
203
+ attentions.append(
204
+ Transformer2DModel(
205
+ attn_num_head_channels,
206
+ in_channels // attn_num_head_channels,
207
+ in_channels=in_channels,
208
+ num_layers=1,
209
+ cross_attention_dim=cross_attention_dim,
210
+ norm_num_groups=resnet_groups,
211
+ use_linear_projection=use_linear_projection,
212
+ upcast_attention=upcast_attention,
213
+ use_gated_attention=use_gated_attention,
214
+ )
215
+ )
216
+ else:
217
+ attentions.append(
218
+ DualTransformer2DModel(
219
+ attn_num_head_channels,
220
+ in_channels // attn_num_head_channels,
221
+ in_channels=in_channels,
222
+ num_layers=1,
223
+ cross_attention_dim=cross_attention_dim,
224
+ norm_num_groups=resnet_groups,
225
+ )
226
+ )
227
+ resnets.append(
228
+ ResnetBlock2D(
229
+ in_channels=in_channels,
230
+ out_channels=in_channels,
231
+ temb_channels=temb_channels,
232
+ eps=resnet_eps,
233
+ groups=resnet_groups,
234
+ dropout=dropout,
235
+ time_embedding_norm=resnet_time_scale_shift,
236
+ non_linearity=resnet_act_fn,
237
+ output_scale_factor=output_scale_factor,
238
+ pre_norm=resnet_pre_norm,
239
+ )
240
+ )
241
+
242
+ self.attentions = nn.ModuleList(attentions)
243
+ self.resnets = nn.ModuleList(resnets)
244
+
245
+ def forward(
246
+ self,
247
+ hidden_states: torch.FloatTensor,
248
+ temb: Optional[torch.FloatTensor] = None,
249
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
250
+ attention_mask: Optional[torch.FloatTensor] = None,
251
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
252
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
253
+ return_cross_attention_probs: bool = False,
254
+ ) -> torch.FloatTensor:
255
+ hidden_states = self.resnets[0](hidden_states, temb)
256
+ cross_attention_probs_all = []
257
+ base_attn_key = cross_attention_kwargs["attn_key"]
258
+ for attn_key, (attn, resnet) in enumerate(zip(self.attentions, self.resnets[1:])):
259
+ cross_attention_kwargs["attn_key"] = base_attn_key + [attn_key]
260
+ hidden_states = attn(
261
+ hidden_states,
262
+ encoder_hidden_states=encoder_hidden_states,
263
+ cross_attention_kwargs=cross_attention_kwargs,
264
+ attention_mask=attention_mask,
265
+ encoder_attention_mask=encoder_attention_mask,
266
+ return_dict=False,
267
+ return_cross_attention_probs=return_cross_attention_probs,
268
+ )
269
+ if return_cross_attention_probs:
270
+ hidden_states, cross_attention_probs = hidden_states
271
+ cross_attention_probs_all.append(cross_attention_probs)
272
+ else:
273
+ hidden_states = hidden_states[0]
274
+ hidden_states = resnet(hidden_states, temb)
275
+
276
+ if return_cross_attention_probs:
277
+ return hidden_states, cross_attention_probs_all
278
+ return hidden_states
279
+
280
+
281
+ class CrossAttnDownBlock2D(nn.Module):
282
+ def __init__(
283
+ self,
284
+ in_channels: int,
285
+ out_channels: int,
286
+ temb_channels: int,
287
+ dropout: float = 0.0,
288
+ num_layers: int = 1,
289
+ resnet_eps: float = 1e-6,
290
+ resnet_time_scale_shift: str = "default",
291
+ resnet_act_fn: str = "swish",
292
+ resnet_groups: int = 32,
293
+ resnet_pre_norm: bool = True,
294
+ attn_num_head_channels=1,
295
+ cross_attention_dim=1280,
296
+ output_scale_factor=1.0,
297
+ downsample_padding=1,
298
+ add_downsample=True,
299
+ dual_cross_attention=False,
300
+ use_linear_projection=False,
301
+ only_cross_attention=False,
302
+ upcast_attention=False,
303
+ use_gated_attention=False,
304
+ ):
305
+ super().__init__()
306
+ resnets = []
307
+ attentions = []
308
+
309
+ self.has_cross_attention = True
310
+ self.attn_num_head_channels = attn_num_head_channels
311
+
312
+ for i in range(num_layers):
313
+ in_channels = in_channels if i == 0 else out_channels
314
+ resnets.append(
315
+ ResnetBlock2D(
316
+ in_channels=in_channels,
317
+ out_channels=out_channels,
318
+ temb_channels=temb_channels,
319
+ eps=resnet_eps,
320
+ groups=resnet_groups,
321
+ dropout=dropout,
322
+ time_embedding_norm=resnet_time_scale_shift,
323
+ non_linearity=resnet_act_fn,
324
+ output_scale_factor=output_scale_factor,
325
+ pre_norm=resnet_pre_norm,
326
+ )
327
+ )
328
+ if not dual_cross_attention:
329
+ attentions.append(
330
+ Transformer2DModel(
331
+ attn_num_head_channels,
332
+ out_channels // attn_num_head_channels,
333
+ in_channels=out_channels,
334
+ num_layers=1,
335
+ cross_attention_dim=cross_attention_dim,
336
+ norm_num_groups=resnet_groups,
337
+ use_linear_projection=use_linear_projection,
338
+ only_cross_attention=only_cross_attention,
339
+ upcast_attention=upcast_attention,
340
+ use_gated_attention=use_gated_attention
341
+ )
342
+ )
343
+ else:
344
+ attentions.append(
345
+ DualTransformer2DModel(
346
+ attn_num_head_channels,
347
+ out_channels // attn_num_head_channels,
348
+ in_channels=out_channels,
349
+ num_layers=1,
350
+ cross_attention_dim=cross_attention_dim,
351
+ norm_num_groups=resnet_groups,
352
+ )
353
+ )
354
+ self.attentions = nn.ModuleList(attentions)
355
+ self.resnets = nn.ModuleList(resnets)
356
+
357
+ if add_downsample:
358
+ self.downsamplers = nn.ModuleList(
359
+ [
360
+ Downsample2D(
361
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
362
+ )
363
+ ]
364
+ )
365
+ else:
366
+ self.downsamplers = None
367
+
368
+ self.gradient_checkpointing = False
369
+
370
+ def forward(
371
+ self,
372
+ hidden_states: torch.FloatTensor,
373
+ temb: Optional[torch.FloatTensor] = None,
374
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
375
+ attention_mask: Optional[torch.FloatTensor] = None,
376
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
377
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
378
+ return_cross_attention_probs: bool = False,
379
+ ):
380
+ output_states = ()
381
+ cross_attention_probs_all = []
382
+ base_attn_key = cross_attention_kwargs["attn_key"]
383
+
384
+ for attn_key, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
385
+
386
+ cross_attention_kwargs["attn_key"] = base_attn_key + [attn_key]
387
+
388
+ if self.training and self.gradient_checkpointing:
389
+
390
+ def create_custom_forward(module, return_dict=None):
391
+ def custom_forward(*inputs):
392
+ if return_dict is not None:
393
+ return module(*inputs, return_dict=return_dict)
394
+ else:
395
+ return module(*inputs)
396
+
397
+ return custom_forward
398
+
399
+ ckpt_kwargs: Dict[str, Any] = {
400
+ "use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
401
+ hidden_states = torch.utils.checkpoint.checkpoint(
402
+ create_custom_forward(resnet),
403
+ hidden_states,
404
+ temb,
405
+ **ckpt_kwargs,
406
+ )
407
+ hidden_states = torch.utils.checkpoint.checkpoint(
408
+ create_custom_forward(attn, return_dict=False),
409
+ hidden_states,
410
+ encoder_hidden_states,
411
+ None, # timestep
412
+ None, # class_labels
413
+ cross_attention_kwargs,
414
+ attention_mask,
415
+ encoder_attention_mask,
416
+ return_cross_attention_probs=return_cross_attention_probs,
417
+ **ckpt_kwargs,
418
+ )
419
+ if return_cross_attention_probs:
420
+ hidden_states, cross_attention_probs = hidden_states
421
+ cross_attention_probs_all.append(cross_attention_probs)
422
+ else:
423
+ hidden_states = hidden_states[0]
424
+ else:
425
+ hidden_states = resnet(hidden_states, temb)
426
+ hidden_states = attn(
427
+ hidden_states,
428
+ encoder_hidden_states=encoder_hidden_states,
429
+ cross_attention_kwargs=cross_attention_kwargs,
430
+ attention_mask=attention_mask,
431
+ encoder_attention_mask=encoder_attention_mask,
432
+ return_dict=False,
433
+ return_cross_attention_probs=return_cross_attention_probs,
434
+ )
435
+ if return_cross_attention_probs:
436
+ hidden_states, cross_attention_probs = hidden_states
437
+ cross_attention_probs_all.append(cross_attention_probs)
438
+ else:
439
+ hidden_states = hidden_states[0]
440
+
441
+ output_states = output_states + (hidden_states,)
442
+
443
+ if self.downsamplers is not None:
444
+ for downsampler in self.downsamplers:
445
+ hidden_states = downsampler(hidden_states)
446
+
447
+ output_states = output_states + (hidden_states,)
448
+
449
+ if return_cross_attention_probs:
450
+ return hidden_states, output_states, cross_attention_probs_all
451
+ return hidden_states, output_states
452
+
453
+
454
+ class DownBlock2D(nn.Module):
455
+ def __init__(
456
+ self,
457
+ in_channels: int,
458
+ out_channels: int,
459
+ temb_channels: int,
460
+ dropout: float = 0.0,
461
+ num_layers: int = 1,
462
+ resnet_eps: float = 1e-6,
463
+ resnet_time_scale_shift: str = "default",
464
+ resnet_act_fn: str = "swish",
465
+ resnet_groups: int = 32,
466
+ resnet_pre_norm: bool = True,
467
+ output_scale_factor=1.0,
468
+ add_downsample=True,
469
+ downsample_padding=1,
470
+ ):
471
+ super().__init__()
472
+ resnets = []
473
+
474
+ for i in range(num_layers):
475
+ in_channels = in_channels if i == 0 else out_channels
476
+ resnets.append(
477
+ ResnetBlock2D(
478
+ in_channels=in_channels,
479
+ out_channels=out_channels,
480
+ temb_channels=temb_channels,
481
+ eps=resnet_eps,
482
+ groups=resnet_groups,
483
+ dropout=dropout,
484
+ time_embedding_norm=resnet_time_scale_shift,
485
+ non_linearity=resnet_act_fn,
486
+ output_scale_factor=output_scale_factor,
487
+ pre_norm=resnet_pre_norm,
488
+ )
489
+ )
490
+
491
+ self.resnets = nn.ModuleList(resnets)
492
+
493
+ if add_downsample:
494
+ self.downsamplers = nn.ModuleList(
495
+ [
496
+ Downsample2D(
497
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
498
+ )
499
+ ]
500
+ )
501
+ else:
502
+ self.downsamplers = None
503
+
504
+ self.gradient_checkpointing = False
505
+
506
+ def forward(self, hidden_states, temb=None):
507
+ output_states = ()
508
+
509
+ for resnet in self.resnets:
510
+ if self.training and self.gradient_checkpointing:
511
+
512
+ def create_custom_forward(module):
513
+ def custom_forward(*inputs):
514
+ return module(*inputs)
515
+
516
+ return custom_forward
517
+
518
+ if is_torch_version(">=", "1.11.0"):
519
+ hidden_states = torch.utils.checkpoint.checkpoint(
520
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
521
+ )
522
+ else:
523
+ hidden_states = torch.utils.checkpoint.checkpoint(
524
+ create_custom_forward(resnet), hidden_states, temb
525
+ )
526
+ else:
527
+ hidden_states = resnet(hidden_states, temb)
528
+
529
+ output_states = output_states + (hidden_states,)
530
+
531
+ if self.downsamplers is not None:
532
+ for downsampler in self.downsamplers:
533
+ hidden_states = downsampler(hidden_states)
534
+
535
+ output_states = output_states + (hidden_states,)
536
+
537
+ return hidden_states, output_states
538
+
539
+
540
+ class CrossAttnUpBlock2D(nn.Module):
541
+ def __init__(
542
+ self,
543
+ in_channels: int,
544
+ out_channels: int,
545
+ prev_output_channel: int,
546
+ temb_channels: int,
547
+ dropout: float = 0.0,
548
+ num_layers: int = 1,
549
+ resnet_eps: float = 1e-6,
550
+ resnet_time_scale_shift: str = "default",
551
+ resnet_act_fn: str = "swish",
552
+ resnet_groups: int = 32,
553
+ resnet_pre_norm: bool = True,
554
+ attn_num_head_channels=1,
555
+ cross_attention_dim=1280,
556
+ output_scale_factor=1.0,
557
+ add_upsample=True,
558
+ dual_cross_attention=False,
559
+ use_linear_projection=False,
560
+ only_cross_attention=False,
561
+ upcast_attention=False,
562
+ use_gated_attention=False,
563
+ ):
564
+ super().__init__()
565
+ resnets = []
566
+ attentions = []
567
+
568
+ self.has_cross_attention = True
569
+ self.attn_num_head_channels = attn_num_head_channels
570
+
571
+ for i in range(num_layers):
572
+ res_skip_channels = in_channels if (
573
+ i == num_layers - 1) else out_channels
574
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
575
+
576
+ resnets.append(
577
+ ResnetBlock2D(
578
+ in_channels=resnet_in_channels + res_skip_channels,
579
+ out_channels=out_channels,
580
+ temb_channels=temb_channels,
581
+ eps=resnet_eps,
582
+ groups=resnet_groups,
583
+ dropout=dropout,
584
+ time_embedding_norm=resnet_time_scale_shift,
585
+ non_linearity=resnet_act_fn,
586
+ output_scale_factor=output_scale_factor,
587
+ pre_norm=resnet_pre_norm,
588
+ )
589
+ )
590
+ if not dual_cross_attention:
591
+ attentions.append(
592
+ Transformer2DModel(
593
+ attn_num_head_channels,
594
+ out_channels // attn_num_head_channels,
595
+ in_channels=out_channels,
596
+ num_layers=1,
597
+ cross_attention_dim=cross_attention_dim,
598
+ norm_num_groups=resnet_groups,
599
+ use_linear_projection=use_linear_projection,
600
+ only_cross_attention=only_cross_attention,
601
+ upcast_attention=upcast_attention,
602
+ use_gated_attention=use_gated_attention,
603
+ )
604
+ )
605
+ else:
606
+ attentions.append(
607
+ DualTransformer2DModel(
608
+ attn_num_head_channels,
609
+ out_channels // attn_num_head_channels,
610
+ in_channels=out_channels,
611
+ num_layers=1,
612
+ cross_attention_dim=cross_attention_dim,
613
+ norm_num_groups=resnet_groups,
614
+ )
615
+ )
616
+ self.attentions = nn.ModuleList(attentions)
617
+ self.resnets = nn.ModuleList(resnets)
618
+
619
+ if add_upsample:
620
+ self.upsamplers = nn.ModuleList(
621
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
622
+ else:
623
+ self.upsamplers = None
624
+
625
+ self.gradient_checkpointing = False
626
+
627
+ def forward(
628
+ self,
629
+ hidden_states: torch.FloatTensor,
630
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
631
+ temb: Optional[torch.FloatTensor] = None,
632
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
633
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
634
+ upsample_size: Optional[int] = None,
635
+ attention_mask: Optional[torch.FloatTensor] = None,
636
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
637
+ return_cross_attention_probs: bool = False,
638
+ ):
639
+ cross_attention_probs_all = []
640
+ base_attn_key = cross_attention_kwargs["attn_key"]
641
+
642
+ for attn_key, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
643
+ cross_attention_kwargs["attn_key"] = base_attn_key + [attn_key]
644
+
645
+ # pop res hidden states
646
+ res_hidden_states = res_hidden_states_tuple[-1]
647
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
648
+ hidden_states = torch.cat(
649
+ [hidden_states, res_hidden_states], dim=1)
650
+
651
+ if self.training and self.gradient_checkpointing:
652
+
653
+ def create_custom_forward(module, return_dict=None):
654
+ def custom_forward(*inputs):
655
+ if return_dict is not None:
656
+ return module(*inputs, return_dict=return_dict)
657
+ else:
658
+ return module(*inputs)
659
+
660
+ return custom_forward
661
+
662
+ ckpt_kwargs: Dict[str, Any] = {
663
+ "use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
664
+ hidden_states = torch.utils.checkpoint.checkpoint(
665
+ create_custom_forward(resnet),
666
+ hidden_states,
667
+ temb,
668
+ **ckpt_kwargs,
669
+ )
670
+ hidden_states = torch.utils.checkpoint.checkpoint(
671
+ create_custom_forward(attn, return_dict=False),
672
+ hidden_states,
673
+ encoder_hidden_states,
674
+ None, # timestep
675
+ None, # class_labels
676
+ cross_attention_kwargs,
677
+ attention_mask,
678
+ encoder_attention_mask,
679
+ **ckpt_kwargs,
680
+ )
681
+ if return_cross_attention_probs:
682
+ hidden_states, cross_attention_probs = hidden_states
683
+ cross_attention_probs_all.append(cross_attention_probs)
684
+ else:
685
+ hidden_states = hidden_states[0]
686
+ else:
687
+ hidden_states = resnet(hidden_states, temb)
688
+ hidden_states = attn(
689
+ hidden_states,
690
+ encoder_hidden_states=encoder_hidden_states,
691
+ cross_attention_kwargs=cross_attention_kwargs,
692
+ attention_mask=attention_mask,
693
+ encoder_attention_mask=encoder_attention_mask,
694
+ return_dict=False,
695
+ return_cross_attention_probs=return_cross_attention_probs,
696
+ )
697
+ if return_cross_attention_probs:
698
+ hidden_states, cross_attention_probs = hidden_states
699
+ cross_attention_probs_all.append(cross_attention_probs)
700
+ else:
701
+ hidden_states = hidden_states[0]
702
+
703
+ if self.upsamplers is not None:
704
+ for upsampler in self.upsamplers:
705
+ hidden_states = upsampler(hidden_states, upsample_size)
706
+
707
+ if return_cross_attention_probs:
708
+ return hidden_states, cross_attention_probs_all
709
+ return hidden_states
710
+
711
+
712
+ class UpBlock2D(nn.Module):
713
+ def __init__(
714
+ self,
715
+ in_channels: int,
716
+ prev_output_channel: int,
717
+ out_channels: int,
718
+ temb_channels: int,
719
+ dropout: float = 0.0,
720
+ num_layers: int = 1,
721
+ resnet_eps: float = 1e-6,
722
+ resnet_time_scale_shift: str = "default",
723
+ resnet_act_fn: str = "swish",
724
+ resnet_groups: int = 32,
725
+ resnet_pre_norm: bool = True,
726
+ output_scale_factor=1.0,
727
+ add_upsample=True,
728
+ ):
729
+ super().__init__()
730
+ resnets = []
731
+
732
+ for i in range(num_layers):
733
+ res_skip_channels = in_channels if (
734
+ i == num_layers - 1) else out_channels
735
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
736
+
737
+ resnets.append(
738
+ ResnetBlock2D(
739
+ in_channels=resnet_in_channels + res_skip_channels,
740
+ out_channels=out_channels,
741
+ temb_channels=temb_channels,
742
+ eps=resnet_eps,
743
+ groups=resnet_groups,
744
+ dropout=dropout,
745
+ time_embedding_norm=resnet_time_scale_shift,
746
+ non_linearity=resnet_act_fn,
747
+ output_scale_factor=output_scale_factor,
748
+ pre_norm=resnet_pre_norm,
749
+ )
750
+ )
751
+
752
+ self.resnets = nn.ModuleList(resnets)
753
+
754
+ if add_upsample:
755
+ self.upsamplers = nn.ModuleList(
756
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
757
+ else:
758
+ self.upsamplers = None
759
+
760
+ self.gradient_checkpointing = False
761
+
762
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
763
+ for resnet in self.resnets:
764
+ # pop res hidden states
765
+ res_hidden_states = res_hidden_states_tuple[-1]
766
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
767
+ hidden_states = torch.cat(
768
+ [hidden_states, res_hidden_states], dim=1)
769
+
770
+ if self.training and self.gradient_checkpointing:
771
+
772
+ def create_custom_forward(module):
773
+ def custom_forward(*inputs):
774
+ return module(*inputs)
775
+
776
+ return custom_forward
777
+
778
+ if is_torch_version(">=", "1.11.0"):
779
+ hidden_states = torch.utils.checkpoint.checkpoint(
780
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
781
+ )
782
+ else:
783
+ hidden_states = torch.utils.checkpoint.checkpoint(
784
+ create_custom_forward(resnet), hidden_states, temb
785
+ )
786
+ else:
787
+ hidden_states = resnet(hidden_states, temb)
788
+
789
+ if self.upsamplers is not None:
790
+ for upsampler in self.upsamplers:
791
+ hidden_states = upsampler(hidden_states, upsample_size)
792
+
793
+ return hidden_states
models/unet_2d_condition.py ADDED
@@ -0,0 +1,980 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ import torch.utils.checkpoint
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.loaders import UNet2DConditionLoadersMixin
24
+ from diffusers.utils import BaseOutput, logging
25
+ from diffusers.models.embeddings import (
26
+ GaussianFourierProjection,
27
+ TextImageProjection,
28
+ TextImageTimeEmbedding,
29
+ TextTimeEmbedding,
30
+ TimestepEmbedding,
31
+ Timesteps,
32
+ )
33
+ from diffusers.models.modeling_utils import ModelMixin
34
+ from .unet_2d_blocks import (
35
+ CrossAttnDownBlock2D,
36
+ CrossAttnUpBlock2D,
37
+ DownBlock2D,
38
+ UNetMidBlock2DCrossAttn,
39
+ UpBlock2D,
40
+ get_down_block,
41
+ get_up_block,
42
+ )
43
+ from .attention_processor import AttentionProcessor, AttnProcessor
44
+
45
+
46
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
47
+
48
+
49
+ @dataclass
50
+ class UNet2DConditionOutput(BaseOutput):
51
+ """
52
+ Args:
53
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
54
+ Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
55
+ """
56
+
57
+ sample: torch.FloatTensor
58
+ cross_attention_probs_down: List[Any]
59
+ cross_attention_probs_mid: List[Any]
60
+ cross_attention_probs_up: List[Any]
61
+
62
+
63
+ class FourierEmbedder(nn.Module):
64
+ def __init__(self, num_freqs=64, temperature=100):
65
+ super().__init__()
66
+
67
+ self.num_freqs = num_freqs
68
+ self.temperature = temperature
69
+
70
+ freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
71
+ freq_bands = freq_bands[None, None, None]
72
+ self.register_buffer('freq_bands', freq_bands, persistent=False)
73
+
74
+ def __call__(self, x):
75
+ x = self.freq_bands * x.unsqueeze(-1)
76
+ return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
77
+
78
+
79
+ class PositionNet(nn.Module):
80
+ def __init__(self, positive_len, out_dim, fourier_freqs=8):
81
+ super().__init__()
82
+ self.positive_len = positive_len
83
+ self.out_dim = out_dim
84
+
85
+ self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
86
+ self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy
87
+
88
+ self.linears = nn.Sequential(
89
+ nn.Linear(self.positive_len + self.position_dim, 512),
90
+ nn.SiLU(),
91
+ nn.Linear(512, 512),
92
+ nn.SiLU(),
93
+ nn.Linear(512, out_dim),
94
+ )
95
+
96
+ self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
97
+ self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
98
+
99
+ def forward(self, boxes, masks, positive_embeddings):
100
+ masks = masks.unsqueeze(-1)
101
+
102
+ # embedding position (it may includes padding as placeholder)
103
+ xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 -> B*N*C
104
+
105
+ # learnable null embedding
106
+ positive_null = self.null_positive_feature.view(1, 1, -1)
107
+ xyxy_null = self.null_position_feature.view(1, 1, -1)
108
+
109
+ # replace padding with learnable null embedding
110
+ positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
111
+ xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
112
+
113
+ objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
114
+ return objs
115
+
116
+
117
+
118
+ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
119
+ r"""
120
+ UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
121
+ and returns sample shaped output.
122
+
123
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
124
+ implements for all the models (such as downloading or saving, etc.)
125
+
126
+ Parameters:
127
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
128
+ Height and width of input/output sample.
129
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
130
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
131
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
132
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
133
+ Whether to flip the sin to cos in the time embedding.
134
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
135
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
136
+ The tuple of downsample blocks to use.
137
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
138
+ The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`, will skip the
139
+ mid block layer if `None`.
140
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
141
+ The tuple of upsample blocks to use.
142
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
143
+ Whether to include self-attention in the basic transformer blocks, see
144
+ [`~models.attention.BasicTransformerBlock`].
145
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
146
+ The tuple of output channels for each block.
147
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
148
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
149
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
150
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
151
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
152
+ If `None`, it will skip the normalization and activation layers in post-processing
153
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
154
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
155
+ The dimension of the cross attention features.
156
+ encoder_hid_dim (`int`, *optional*, defaults to None):
157
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
158
+ dimension to `cross_attention_dim`.
159
+ encoder_hid_dim_type (`str`, *optional*, defaults to None):
160
+ If given, the `encoder_hidden_states` and potentially other embeddings will be down-projected to text
161
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
162
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
163
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
164
+ for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
165
+ class_embed_type (`str`, *optional*, defaults to None):
166
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
167
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
168
+ addition_embed_type (`str`, *optional*, defaults to None):
169
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
170
+ "text". "text" will use the `TextTimeEmbedding` layer.
171
+ num_class_embeds (`int`, *optional*, defaults to None):
172
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
173
+ class conditioning with `class_embed_type` equal to `None`.
174
+ time_embedding_type (`str`, *optional*, default to `positional`):
175
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
176
+ time_embedding_dim (`int`, *optional*, default to `None`):
177
+ An optional override for the dimension of the projected time embedding.
178
+ time_embedding_act_fn (`str`, *optional*, default to `None`):
179
+ Optional activation function to use on the time embeddings only one time before they as passed to the rest
180
+ of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`.
181
+ timestep_post_act (`str, *optional*, default to `None`):
182
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
183
+ time_cond_proj_dim (`int`, *optional*, default to `None`):
184
+ The dimension of `cond_proj` layer in timestep embedding.
185
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
186
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
187
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
188
+ using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
189
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
190
+ embeddings with the class embeddings.
191
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
192
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
193
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is None, the
194
+ `only_cross_attention` value will be used as the value for `mid_block_only_cross_attention`. Else, it will
195
+ default to `False`.
196
+ """
197
+
198
+ _supports_gradient_checkpointing = True
199
+
200
+ @register_to_config
201
+ def __init__(
202
+ self,
203
+ sample_size: Optional[int] = None,
204
+ in_channels: int = 4,
205
+ out_channels: int = 4,
206
+ center_input_sample: bool = False,
207
+ flip_sin_to_cos: bool = True,
208
+ freq_shift: int = 0,
209
+ down_block_types: Tuple[str] = (
210
+ "CrossAttnDownBlock2D",
211
+ "CrossAttnDownBlock2D",
212
+ "CrossAttnDownBlock2D",
213
+ "DownBlock2D",
214
+ ),
215
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
216
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
217
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
218
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
219
+ layers_per_block: Union[int, Tuple[int]] = 2,
220
+ downsample_padding: int = 1,
221
+ mid_block_scale_factor: float = 1,
222
+ act_fn: str = "silu",
223
+ norm_num_groups: Optional[int] = 32,
224
+ norm_eps: float = 1e-5,
225
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
226
+ encoder_hid_dim: Optional[int] = None,
227
+ encoder_hid_dim_type: Optional[str] = None,
228
+ attention_head_dim: Union[int, Tuple[int]] = 8,
229
+ dual_cross_attention: bool = False,
230
+ use_linear_projection: bool = False,
231
+ class_embed_type: Optional[str] = None,
232
+ addition_embed_type: Optional[str] = None,
233
+ num_class_embeds: Optional[int] = None,
234
+ upcast_attention: bool = False,
235
+ resnet_time_scale_shift: str = "default",
236
+ resnet_skip_time_act: bool = False,
237
+ resnet_out_scale_factor: int = 1.0,
238
+ time_embedding_type: str = "positional",
239
+ time_embedding_dim: Optional[int] = None,
240
+ time_embedding_act_fn: Optional[str] = None,
241
+ timestep_post_act: Optional[str] = None,
242
+ time_cond_proj_dim: Optional[int] = None,
243
+ conv_in_kernel: int = 3,
244
+ conv_out_kernel: int = 3,
245
+ projection_class_embeddings_input_dim: Optional[int] = None,
246
+ class_embeddings_concat: bool = False,
247
+ mid_block_only_cross_attention: Optional[bool] = None,
248
+ cross_attention_norm: Optional[str] = None,
249
+ addition_embed_type_num_heads=64,
250
+ use_gated_attention: bool = False,
251
+ ):
252
+ super().__init__()
253
+
254
+ self.sample_size = sample_size
255
+
256
+ # Check inputs
257
+ if len(down_block_types) != len(up_block_types):
258
+ raise ValueError(
259
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
260
+ )
261
+
262
+ if len(block_out_channels) != len(down_block_types):
263
+ raise ValueError(
264
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
265
+ )
266
+
267
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
268
+ raise ValueError(
269
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
270
+ )
271
+
272
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
273
+ raise ValueError(
274
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
275
+ )
276
+
277
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
278
+ raise ValueError(
279
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
280
+ )
281
+
282
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
283
+ raise ValueError(
284
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
285
+ )
286
+
287
+ # input
288
+ conv_in_padding = (conv_in_kernel - 1) // 2
289
+ self.conv_in = nn.Conv2d(
290
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
291
+ )
292
+
293
+ # time
294
+ if time_embedding_type == "fourier":
295
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
296
+ if time_embed_dim % 2 != 0:
297
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
298
+ self.time_proj = GaussianFourierProjection(
299
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
300
+ )
301
+ timestep_input_dim = time_embed_dim
302
+ elif time_embedding_type == "positional":
303
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
304
+
305
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
306
+ timestep_input_dim = block_out_channels[0]
307
+ else:
308
+ raise ValueError(
309
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
310
+ )
311
+
312
+ self.time_embedding = TimestepEmbedding(
313
+ timestep_input_dim,
314
+ time_embed_dim,
315
+ act_fn=act_fn,
316
+ post_act_fn=timestep_post_act,
317
+ cond_proj_dim=time_cond_proj_dim,
318
+ )
319
+
320
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
321
+ encoder_hid_dim_type = "text_proj"
322
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
323
+
324
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
325
+ raise ValueError(
326
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
327
+ )
328
+
329
+ if encoder_hid_dim_type == "text_proj":
330
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
331
+ elif encoder_hid_dim_type == "text_image_proj":
332
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
333
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
334
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
335
+ self.encoder_hid_proj = TextImageProjection(
336
+ text_embed_dim=encoder_hid_dim,
337
+ image_embed_dim=cross_attention_dim,
338
+ cross_attention_dim=cross_attention_dim,
339
+ )
340
+
341
+ elif encoder_hid_dim_type is not None:
342
+ raise ValueError(
343
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
344
+ )
345
+ else:
346
+ self.encoder_hid_proj = None
347
+
348
+ # class embedding
349
+ if class_embed_type is None and num_class_embeds is not None:
350
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
351
+ elif class_embed_type == "timestep":
352
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
353
+ elif class_embed_type == "identity":
354
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
355
+ elif class_embed_type == "projection":
356
+ if projection_class_embeddings_input_dim is None:
357
+ raise ValueError(
358
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
359
+ )
360
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
361
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
362
+ # 2. it projects from an arbitrary input dimension.
363
+ #
364
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
365
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
366
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
367
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
368
+ elif class_embed_type == "simple_projection":
369
+ if projection_class_embeddings_input_dim is None:
370
+ raise ValueError(
371
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
372
+ )
373
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
374
+ else:
375
+ self.class_embedding = None
376
+
377
+ if addition_embed_type == "text":
378
+ if encoder_hid_dim is not None:
379
+ text_time_embedding_from_dim = encoder_hid_dim
380
+ else:
381
+ text_time_embedding_from_dim = cross_attention_dim
382
+
383
+ self.add_embedding = TextTimeEmbedding(
384
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
385
+ )
386
+ elif addition_embed_type == "text_image":
387
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
388
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
389
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
390
+ self.add_embedding = TextImageTimeEmbedding(
391
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
392
+ )
393
+ elif addition_embed_type is not None:
394
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
395
+
396
+ if time_embedding_act_fn is None:
397
+ self.time_embed_act = None
398
+ elif time_embedding_act_fn == "swish":
399
+ self.time_embed_act = lambda x: F.silu(x)
400
+ elif time_embedding_act_fn == "mish":
401
+ self.time_embed_act = nn.Mish()
402
+ elif time_embedding_act_fn == "silu":
403
+ self.time_embed_act = nn.SiLU()
404
+ elif time_embedding_act_fn == "gelu":
405
+ self.time_embed_act = nn.GELU()
406
+ else:
407
+ raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}")
408
+
409
+ self.down_blocks = nn.ModuleList([])
410
+ self.up_blocks = nn.ModuleList([])
411
+
412
+ if isinstance(only_cross_attention, bool):
413
+ if mid_block_only_cross_attention is None:
414
+ mid_block_only_cross_attention = only_cross_attention
415
+
416
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
417
+
418
+ if mid_block_only_cross_attention is None:
419
+ mid_block_only_cross_attention = False
420
+
421
+ if isinstance(attention_head_dim, int):
422
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
423
+
424
+ if isinstance(cross_attention_dim, int):
425
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
426
+ else:
427
+ assert not use_gated_attention, f"use_gated_attention is not supported with varying cross_attention_dim: {cross_attention_dim}"
428
+
429
+ if isinstance(layers_per_block, int):
430
+ layers_per_block = [layers_per_block] * len(down_block_types)
431
+
432
+ if class_embeddings_concat:
433
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
434
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
435
+ # regular time embeddings
436
+ blocks_time_embed_dim = time_embed_dim * 2
437
+ else:
438
+ blocks_time_embed_dim = time_embed_dim
439
+
440
+ # down
441
+ output_channel = block_out_channels[0]
442
+ for i, down_block_type in enumerate(down_block_types):
443
+ input_channel = output_channel
444
+ output_channel = block_out_channels[i]
445
+ is_final_block = i == len(block_out_channels) - 1
446
+
447
+ down_block = get_down_block(
448
+ down_block_type,
449
+ num_layers=layers_per_block[i],
450
+ in_channels=input_channel,
451
+ out_channels=output_channel,
452
+ temb_channels=blocks_time_embed_dim,
453
+ add_downsample=not is_final_block,
454
+ resnet_eps=norm_eps,
455
+ resnet_act_fn=act_fn,
456
+ resnet_groups=norm_num_groups,
457
+ cross_attention_dim=cross_attention_dim[i],
458
+ attn_num_head_channels=attention_head_dim[i],
459
+ downsample_padding=downsample_padding,
460
+ dual_cross_attention=dual_cross_attention,
461
+ use_linear_projection=use_linear_projection,
462
+ only_cross_attention=only_cross_attention[i],
463
+ upcast_attention=upcast_attention,
464
+ resnet_time_scale_shift=resnet_time_scale_shift,
465
+ resnet_skip_time_act=resnet_skip_time_act,
466
+ resnet_out_scale_factor=resnet_out_scale_factor,
467
+ cross_attention_norm=cross_attention_norm,
468
+ use_gated_attention=use_gated_attention,
469
+ )
470
+ self.down_blocks.append(down_block)
471
+
472
+ # mid
473
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
474
+ self.mid_block = UNetMidBlock2DCrossAttn(
475
+ in_channels=block_out_channels[-1],
476
+ temb_channels=blocks_time_embed_dim,
477
+ resnet_eps=norm_eps,
478
+ resnet_act_fn=act_fn,
479
+ output_scale_factor=mid_block_scale_factor,
480
+ resnet_time_scale_shift=resnet_time_scale_shift,
481
+ cross_attention_dim=cross_attention_dim[-1],
482
+ attn_num_head_channels=attention_head_dim[-1],
483
+ resnet_groups=norm_num_groups,
484
+ dual_cross_attention=dual_cross_attention,
485
+ use_linear_projection=use_linear_projection,
486
+ upcast_attention=upcast_attention,
487
+ use_gated_attention=use_gated_attention,
488
+ )
489
+ elif mid_block_type is None:
490
+ self.mid_block = None
491
+ else:
492
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
493
+
494
+ # count how many layers upsample the images
495
+ self.num_upsamplers = 0
496
+
497
+ # up
498
+ reversed_block_out_channels = list(reversed(block_out_channels))
499
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
500
+ reversed_layers_per_block = list(reversed(layers_per_block))
501
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
502
+ only_cross_attention = list(reversed(only_cross_attention))
503
+
504
+ output_channel = reversed_block_out_channels[0]
505
+ for i, up_block_type in enumerate(up_block_types):
506
+ is_final_block = i == len(block_out_channels) - 1
507
+
508
+ prev_output_channel = output_channel
509
+ output_channel = reversed_block_out_channels[i]
510
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
511
+
512
+ # add upsample block for all BUT final layer
513
+ if not is_final_block:
514
+ add_upsample = True
515
+ self.num_upsamplers += 1
516
+ else:
517
+ add_upsample = False
518
+
519
+ up_block = get_up_block(
520
+ up_block_type,
521
+ num_layers=reversed_layers_per_block[i] + 1,
522
+ in_channels=input_channel,
523
+ out_channels=output_channel,
524
+ prev_output_channel=prev_output_channel,
525
+ temb_channels=blocks_time_embed_dim,
526
+ add_upsample=add_upsample,
527
+ resnet_eps=norm_eps,
528
+ resnet_act_fn=act_fn,
529
+ resnet_groups=norm_num_groups,
530
+ cross_attention_dim=reversed_cross_attention_dim[i],
531
+ attn_num_head_channels=reversed_attention_head_dim[i],
532
+ dual_cross_attention=dual_cross_attention,
533
+ use_linear_projection=use_linear_projection,
534
+ only_cross_attention=only_cross_attention[i],
535
+ upcast_attention=upcast_attention,
536
+ resnet_time_scale_shift=resnet_time_scale_shift,
537
+ resnet_skip_time_act=resnet_skip_time_act,
538
+ resnet_out_scale_factor=resnet_out_scale_factor,
539
+ cross_attention_norm=cross_attention_norm,
540
+ use_gated_attention=use_gated_attention,
541
+ )
542
+ self.up_blocks.append(up_block)
543
+ prev_output_channel = output_channel
544
+
545
+ # out
546
+ if norm_num_groups is not None:
547
+ self.conv_norm_out = nn.GroupNorm(
548
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
549
+ )
550
+
551
+ if act_fn == "swish":
552
+ self.conv_act = lambda x: F.silu(x)
553
+ elif act_fn == "mish":
554
+ self.conv_act = nn.Mish()
555
+ elif act_fn == "silu":
556
+ self.conv_act = nn.SiLU()
557
+ elif act_fn == "gelu":
558
+ self.conv_act = nn.GELU()
559
+ else:
560
+ raise ValueError(f"Unsupported activation function: {act_fn}")
561
+
562
+ else:
563
+ self.conv_norm_out = None
564
+ self.conv_act = None
565
+
566
+ conv_out_padding = (conv_out_kernel - 1) // 2
567
+ self.conv_out = nn.Conv2d(
568
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
569
+ )
570
+
571
+ if use_gated_attention:
572
+ self.position_net = PositionNet(positive_len=768, out_dim=cross_attention_dim[-1])
573
+
574
+
575
+ @property
576
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
577
+ r"""
578
+ Returns:
579
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
580
+ indexed by its weight name.
581
+ """
582
+ # set recursively
583
+ processors = {}
584
+
585
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
586
+ if hasattr(module, "set_processor"):
587
+ processors[f"{name}.processor"] = module.processor
588
+
589
+ for sub_name, child in module.named_children():
590
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
591
+
592
+ return processors
593
+
594
+ for name, module in self.named_children():
595
+ fn_recursive_add_processors(name, module, processors)
596
+
597
+ return processors
598
+
599
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
600
+ r"""
601
+ Parameters:
602
+ `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
603
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
604
+ of **all** `Attention` layers.
605
+ In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:
606
+
607
+ """
608
+ count = len(self.attn_processors.keys())
609
+
610
+ if isinstance(processor, dict) and len(processor) != count:
611
+ raise ValueError(
612
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
613
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
614
+ )
615
+
616
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
617
+ if hasattr(module, "set_processor"):
618
+ if not isinstance(processor, dict):
619
+ module.set_processor(processor)
620
+ else:
621
+ module.set_processor(processor.pop(f"{name}.processor"))
622
+
623
+ for sub_name, child in module.named_children():
624
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
625
+
626
+ for name, module in self.named_children():
627
+ fn_recursive_attn_processor(name, module, processor)
628
+
629
+ def set_default_attn_processor(self):
630
+ """
631
+ Disables custom attention processors and sets the default attention implementation.
632
+ """
633
+ self.set_attn_processor(AttnProcessor())
634
+
635
+ def set_attention_slice(self, slice_size):
636
+ r"""
637
+ Enable sliced attention computation.
638
+
639
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
640
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
641
+
642
+ Args:
643
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
644
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
645
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
646
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
647
+ must be a multiple of `slice_size`.
648
+ """
649
+ sliceable_head_dims = []
650
+
651
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
652
+ if hasattr(module, "set_attention_slice"):
653
+ sliceable_head_dims.append(module.sliceable_head_dim)
654
+
655
+ for child in module.children():
656
+ fn_recursive_retrieve_sliceable_dims(child)
657
+
658
+ # retrieve number of attention layers
659
+ for module in self.children():
660
+ fn_recursive_retrieve_sliceable_dims(module)
661
+
662
+ num_sliceable_layers = len(sliceable_head_dims)
663
+
664
+ if slice_size == "auto":
665
+ # half the attention head size is usually a good trade-off between
666
+ # speed and memory
667
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
668
+ elif slice_size == "max":
669
+ # make smallest slice possible
670
+ slice_size = num_sliceable_layers * [1]
671
+
672
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
673
+
674
+ if len(slice_size) != len(sliceable_head_dims):
675
+ raise ValueError(
676
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
677
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
678
+ )
679
+
680
+ for i in range(len(slice_size)):
681
+ size = slice_size[i]
682
+ dim = sliceable_head_dims[i]
683
+ if size is not None and size > dim:
684
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
685
+
686
+ # Recursively walk through all the children.
687
+ # Any children which exposes the set_attention_slice method
688
+ # gets the message
689
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
690
+ if hasattr(module, "set_attention_slice"):
691
+ module.set_attention_slice(slice_size.pop())
692
+
693
+ for child in module.children():
694
+ fn_recursive_set_attention_slice(child, slice_size)
695
+
696
+ reversed_slice_size = list(reversed(slice_size))
697
+ for module in self.children():
698
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
699
+
700
+ def _set_gradient_checkpointing(self, module, value=False):
701
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
702
+ module.gradient_checkpointing = value
703
+
704
+ def forward(
705
+ self,
706
+ sample: torch.FloatTensor,
707
+ timestep: Union[torch.Tensor, float, int],
708
+ encoder_hidden_states: torch.Tensor,
709
+ class_labels: Optional[torch.Tensor] = None,
710
+ timestep_cond: Optional[torch.Tensor] = None,
711
+ attention_mask: Optional[torch.Tensor] = None,
712
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
713
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
714
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
715
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
716
+ encoder_attention_mask: Optional[torch.Tensor] = None,
717
+ return_dict: bool = True,
718
+ return_cross_attention_probs: bool = False
719
+ ) -> Union[UNet2DConditionOutput, Tuple]:
720
+ r"""
721
+ Args:
722
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
723
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
724
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
725
+ encoder_attention_mask (`torch.Tensor`):
726
+ (batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep, False =
727
+ discard. Mask will be converted into a bias, which adds large negative values to attention scores
728
+ corresponding to "discard" tokens.
729
+ return_dict (`bool`, *optional*, defaults to `True`):
730
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
731
+ cross_attention_kwargs (`dict`, *optional*):
732
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
733
+ `self.processor` in
734
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
735
+ added_cond_kwargs (`dict`, *optional*):
736
+ A kwargs dictionary that if specified includes additonal conditions that can be used for additonal time
737
+ embeddings or encoder hidden states projections. See the configurations `encoder_hid_dim_type` and
738
+ `addition_embed_type` for more information.
739
+
740
+ Returns:
741
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
742
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
743
+ returning a tuple, the first element is the sample tensor.
744
+ """
745
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
746
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
747
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
748
+ # on the fly if necessary.
749
+ default_overall_up_factor = 2**self.num_upsamplers
750
+
751
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
752
+ forward_upsample_size = False
753
+ upsample_size = None
754
+
755
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
756
+ logger.info("Forward upsample size to force interpolation output size.")
757
+ forward_upsample_size = True
758
+
759
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
760
+ # expects mask of shape:
761
+ # [batch, key_tokens]
762
+ # adds singleton query_tokens dimension:
763
+ # [batch, 1, key_tokens]
764
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
765
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
766
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
767
+ if attention_mask is not None:
768
+ # assume that mask is expressed as:
769
+ # (1 = keep, 0 = discard)
770
+ # convert mask into a bias that can be added to attention scores:
771
+ # (keep = +0, discard = -10000.0)
772
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
773
+ attention_mask = attention_mask.unsqueeze(1)
774
+
775
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
776
+ if encoder_attention_mask is not None:
777
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
778
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
779
+
780
+ # 0. center input if necessary
781
+ if self.config.center_input_sample:
782
+ sample = 2 * sample - 1.0
783
+
784
+ # 1. time
785
+ timesteps = timestep
786
+ if not torch.is_tensor(timesteps):
787
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
788
+ # This would be a good case for the `match` statement (Python 3.10+)
789
+ is_mps = sample.device.type == "mps"
790
+ if isinstance(timestep, float):
791
+ dtype = torch.float32 if is_mps else torch.float64
792
+ else:
793
+ dtype = torch.int32 if is_mps else torch.int64
794
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
795
+ elif len(timesteps.shape) == 0:
796
+ timesteps = timesteps[None].to(sample.device)
797
+
798
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
799
+ timesteps = timesteps.expand(sample.shape[0])
800
+
801
+ t_emb = self.time_proj(timesteps)
802
+
803
+ # `Timesteps` does not contain any weights and will always return f32 tensors
804
+ # but time_embedding might actually be running in fp16. so we need to cast here.
805
+ # there might be better ways to encapsulate this.
806
+ t_emb = t_emb.to(dtype=sample.dtype)
807
+
808
+ emb = self.time_embedding(t_emb, timestep_cond)
809
+
810
+ if self.class_embedding is not None:
811
+ if class_labels is None:
812
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
813
+
814
+ if self.config.class_embed_type == "timestep":
815
+ class_labels = self.time_proj(class_labels)
816
+
817
+ # `Timesteps` does not contain any weights and will always return f32 tensors
818
+ # there might be better ways to encapsulate this.
819
+ class_labels = class_labels.to(dtype=sample.dtype)
820
+
821
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
822
+
823
+ if self.config.class_embeddings_concat:
824
+ emb = torch.cat([emb, class_emb], dim=-1)
825
+ else:
826
+ emb = emb + class_emb
827
+
828
+ if self.config.addition_embed_type == "text":
829
+ aug_emb = self.add_embedding(encoder_hidden_states)
830
+ emb = emb + aug_emb
831
+ elif self.config.addition_embed_type == "text_image":
832
+ # Kadinsky 2.1 - style
833
+ if "image_embeds" not in added_cond_kwargs:
834
+ raise ValueError(
835
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
836
+ )
837
+
838
+ image_embs = added_cond_kwargs.get("image_embeds")
839
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
840
+
841
+ aug_emb = self.add_embedding(text_embs, image_embs)
842
+ emb = emb + aug_emb
843
+
844
+ if self.time_embed_act is not None:
845
+ emb = self.time_embed_act(emb)
846
+
847
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
848
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
849
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
850
+ # Kadinsky 2.1 - style
851
+ if "image_embeds" not in added_cond_kwargs:
852
+ raise ValueError(
853
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
854
+ )
855
+
856
+ image_embeds = added_cond_kwargs.get("image_embeds")
857
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
858
+
859
+ # 2. pre-process
860
+ sample = self.conv_in(sample)
861
+
862
+ # 2.5 GLIGEN position net
863
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get('gligen', None) is not None:
864
+ cross_attention_kwargs = cross_attention_kwargs.copy()
865
+ cross_attention_kwargs['gligen'] = {
866
+ 'objs': self.position_net(
867
+ boxes=cross_attention_kwargs['gligen']['boxes'],
868
+ masks=cross_attention_kwargs['gligen']['masks'],
869
+ positive_embeddings=cross_attention_kwargs['gligen']['positive_embeddings']
870
+ ),
871
+ 'fuser_attn_kwargs': cross_attention_kwargs['gligen'].get('fuser_attn_kwargs', {})
872
+ }
873
+
874
+ # 3. down
875
+ down_block_res_samples = (sample,)
876
+ cross_attention_probs_down = []
877
+ if cross_attention_kwargs is None:
878
+ cross_attention_kwargs = {}
879
+
880
+ for i, downsample_block in enumerate(self.down_blocks):
881
+ cross_attention_kwargs["attn_key"] = ["down", i]
882
+
883
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
884
+ downsample_block_output = downsample_block(
885
+ hidden_states=sample,
886
+ temb=emb,
887
+ encoder_hidden_states=encoder_hidden_states,
888
+ attention_mask=attention_mask,
889
+ cross_attention_kwargs=cross_attention_kwargs,
890
+ encoder_attention_mask=encoder_attention_mask,
891
+ return_cross_attention_probs=return_cross_attention_probs,
892
+ )
893
+ if return_cross_attention_probs:
894
+ sample, res_samples, cross_attention_probs = downsample_block_output
895
+ cross_attention_probs_down.append(cross_attention_probs)
896
+ else:
897
+ sample, res_samples = downsample_block_output
898
+ else:
899
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
900
+
901
+ down_block_res_samples += res_samples
902
+
903
+ if down_block_additional_residuals is not None:
904
+ new_down_block_res_samples = ()
905
+
906
+ for down_block_res_sample, down_block_additional_residual in zip(
907
+ down_block_res_samples, down_block_additional_residuals
908
+ ):
909
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
910
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
911
+
912
+ down_block_res_samples = new_down_block_res_samples
913
+
914
+ # 4. mid
915
+ cross_attention_probs_mid = []
916
+ if self.mid_block is not None:
917
+ cross_attention_kwargs["attn_key"] = ["mid", 0]
918
+
919
+ sample = self.mid_block(
920
+ sample,
921
+ emb,
922
+ encoder_hidden_states=encoder_hidden_states,
923
+ attention_mask=attention_mask,
924
+ cross_attention_kwargs=cross_attention_kwargs,
925
+ encoder_attention_mask=encoder_attention_mask,
926
+ return_cross_attention_probs=return_cross_attention_probs,
927
+ )
928
+ if return_cross_attention_probs:
929
+ sample, cross_attention_probs = sample
930
+ cross_attention_probs_mid.append(cross_attention_probs)
931
+
932
+
933
+ if mid_block_additional_residual is not None:
934
+ sample = sample + mid_block_additional_residual
935
+
936
+ cross_attention_probs_up = []
937
+ # 5. up
938
+ for i, upsample_block in enumerate(self.up_blocks):
939
+ cross_attention_kwargs["attn_key"] = ["up", i]
940
+
941
+ is_final_block = i == len(self.up_blocks) - 1
942
+
943
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
944
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
945
+
946
+ # if we have not reached the final block and need to forward the
947
+ # upsample size, we do it here
948
+ if not is_final_block and forward_upsample_size:
949
+ upsample_size = down_block_res_samples[-1].shape[2:]
950
+
951
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
952
+ sample = upsample_block(
953
+ hidden_states=sample,
954
+ temb=emb,
955
+ res_hidden_states_tuple=res_samples,
956
+ encoder_hidden_states=encoder_hidden_states,
957
+ cross_attention_kwargs=cross_attention_kwargs,
958
+ upsample_size=upsample_size,
959
+ attention_mask=attention_mask,
960
+ encoder_attention_mask=encoder_attention_mask,
961
+ return_cross_attention_probs=return_cross_attention_probs,
962
+ )
963
+ if return_cross_attention_probs:
964
+ sample, cross_attention_probs = sample
965
+ cross_attention_probs_up.append(cross_attention_probs)
966
+ else:
967
+ sample = upsample_block(
968
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
969
+ )
970
+
971
+ # 6. post-process
972
+ if self.conv_norm_out:
973
+ sample = self.conv_norm_out(sample)
974
+ sample = self.conv_act(sample)
975
+ sample = self.conv_out(sample)
976
+
977
+ if not return_dict:
978
+ return (sample,)
979
+
980
+ return UNet2DConditionOutput(sample=sample, cross_attention_probs_down=cross_attention_probs_down, cross_attention_probs_mid=cross_attention_probs_mid, cross_attention_probs_up=cross_attention_probs_up)
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu113
2
+ numpy
3
+ scipy
4
+ torch==2.0.0
5
+ diffusers==0.17.0
6
+ transformers==4.29.2
7
+ opencv-python==4.7.0.72
8
+ opencv-contrib-python==4.7.0.72
9
+ inflect==6.0.4
10
+ easydict
11
+ accelerate==0.18.0
shared.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from models import load_sd, sam
2
+
3
+ use_fp16 = False
4
+ use_dpm = True
5
+
6
+ sd_key = "gligen/diffusers-generation-text-box"
7
+
8
+ print(f"Using SD: {sd_key}")
9
+ model_dict = load_sd(key=sd_key, use_fp16=use_fp16, use_dpm_multistep_scheduler=use_dpm, load_inverse_scheduler=False)
10
+
11
+ sam_model_dict = sam.load_sam()
utils/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .utils import *
utils/latents.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from . import utils
4
+ from utils import torch_device
5
+ import matplotlib.pyplot as plt
6
+
7
+ def get_unscaled_latents(batch_size, in_channels, height, width, generator, dtype):
8
+ """
9
+ in_channels: often obtained with `unet.config.in_channels`
10
+ """
11
+ # Obtain with torch.float32 and cast to float16 if needed
12
+ # Directly obtaining latents in float16 will lead to different latents
13
+ latents_base = torch.randn(
14
+ (batch_size, in_channels, height // 8, width // 8),
15
+ generator=generator, dtype=dtype
16
+ ).to(torch_device, dtype=dtype)
17
+
18
+ return latents_base
19
+
20
+ def get_scaled_latents(batch_size, in_channels, height, width, generator, dtype, scheduler):
21
+ latents_base = get_unscaled_latents(batch_size, in_channels, height, width, generator, dtype)
22
+ latents_base = latents_base * scheduler.init_noise_sigma
23
+ return latents_base
24
+
25
+ def blend_latents(latents_bg, latents_fg, fg_mask, fg_blending_ratio=0.01):
26
+ """
27
+ in_channels: often obtained with `unet.config.in_channels`
28
+ """
29
+ assert not torch.allclose(latents_bg, latents_fg), "latents_bg should be independent with latents_fg"
30
+
31
+ dtype = latents_bg.dtype
32
+ latents = latents_bg * (1. - fg_mask) + (latents_bg * np.sqrt(1. - fg_blending_ratio) + latents_fg * np.sqrt(fg_blending_ratio)) * fg_mask
33
+ latents = latents.to(dtype=dtype)
34
+
35
+ return latents
36
+
37
+ @torch.no_grad()
38
+ def compose_latents(model_dict, latents_all_list, mask_tensor_list, num_inference_steps, overall_batch_size, height, width, latents_bg=None, bg_seed=None, compose_box_to_bg=True):
39
+ unet, scheduler, dtype = model_dict.unet, model_dict.scheduler, model_dict.dtype
40
+
41
+ if latents_bg is None:
42
+ generator = torch.manual_seed(bg_seed) # Seed generator to create the inital latent noise
43
+ latents_bg = get_scaled_latents(overall_batch_size, unet.config.in_channels, height, width, generator, dtype, scheduler)
44
+
45
+ # Other than t=T (idx=0), we only have masked latents. This is to prevent accidentally loading from non-masked part. Use same mask as the one used to compose the latents.
46
+ composed_latents = torch.zeros((num_inference_steps + 1, *latents_bg.shape), dtype=dtype)
47
+ composed_latents[0] = latents_bg
48
+
49
+ foreground_indices = torch.zeros(latents_bg.shape[-2:], dtype=torch.long)
50
+
51
+ mask_size = np.array([mask_tensor.sum().item() for mask_tensor in mask_tensor_list])
52
+ # Compose the largest mask first
53
+ mask_order = np.argsort(-mask_size)
54
+
55
+ if compose_box_to_bg:
56
+ # This has two functionalities:
57
+ # 1. copies the right initial latents from the right place (for centered so generation), 2. copies the right initial latents (since we have foreground blending) for centered/original so generation.
58
+ for mask_idx in mask_order:
59
+ latents_all, mask_tensor = latents_all_list[mask_idx], mask_tensor_list[mask_idx]
60
+
61
+ # Note: need to be careful to not copy from zeros due to shifting.
62
+ mask_tensor = utils.binary_mask_to_box_mask(mask_tensor, to_device=False)
63
+
64
+ mask_tensor_expanded = mask_tensor[None, None, None, ...].to(dtype)
65
+ composed_latents[0] = composed_latents[0] * (1. - mask_tensor_expanded) + latents_all[0] * mask_tensor_expanded
66
+
67
+ # This is still needed with `compose_box_to_bg` to ensure the foreground latent is still visible and to compute foreground indices.
68
+ for mask_idx in mask_order:
69
+ latents_all, mask_tensor = latents_all_list[mask_idx], mask_tensor_list[mask_idx]
70
+ foreground_indices = foreground_indices * (~mask_tensor) + (mask_idx + 1) * mask_tensor
71
+ mask_tensor_expanded = mask_tensor[None, None, None, ...].to(dtype)
72
+ composed_latents = composed_latents * (1. - mask_tensor_expanded) + latents_all * mask_tensor_expanded
73
+
74
+ composed_latents, foreground_indices = composed_latents.to(torch_device), foreground_indices.to(torch_device)
75
+ return composed_latents, foreground_indices
76
+
77
+ def align_with_bboxes(latents_all_list, mask_tensor_list, bboxes, horizontal_shift_only=False):
78
+ """
79
+ Each offset in `offset_list` is `(x_offset, y_offset)` (normalized).
80
+ """
81
+ new_latents_all_list, new_mask_tensor_list, offset_list = [], [], []
82
+ for latents_all, mask_tensor, bbox in zip(latents_all_list, mask_tensor_list, bboxes):
83
+ x_src_center, y_src_center = utils.binary_mask_to_center(mask_tensor, normalize=True)
84
+ x_min_dest, y_min_dest, x_max_dest, y_max_dest = bbox
85
+ x_dest_center, y_dest_center = (x_min_dest + x_max_dest) / 2, (y_min_dest + y_max_dest) / 2
86
+ # print("src (x,y):", x_src_center, y_src_center, "dest (x,y):", x_dest_center, y_dest_center)
87
+ x_offset, y_offset = x_dest_center - x_src_center, y_dest_center - y_src_center
88
+ if horizontal_shift_only:
89
+ y_offset = 0.
90
+ offset = x_offset, y_offset
91
+ latents_all = utils.shift_tensor(latents_all, x_offset, y_offset, offset_normalized=True)
92
+ mask_tensor = utils.shift_tensor(mask_tensor, x_offset, y_offset, offset_normalized=True)
93
+ new_latents_all_list.append(latents_all)
94
+ new_mask_tensor_list.append(mask_tensor)
95
+ offset_list.append(offset)
96
+
97
+ return new_latents_all_list, new_mask_tensor_list, offset_list
98
+
99
+ @torch.no_grad()
100
+ def compose_latents_with_alignment(
101
+ model_dict, latents_all_list, mask_tensor_list, num_inference_steps, overall_batch_size, height, width,
102
+ align_with_overall_bboxes=True, overall_bboxes=None, horizontal_shift_only=False, **kwargs
103
+ ):
104
+ if align_with_overall_bboxes and len(latents_all_list):
105
+ expanded_overall_bboxes = utils.expand_overall_bboxes(overall_bboxes)
106
+ latents_all_list, mask_tensor_list, offset_list = align_with_bboxes(latents_all_list, mask_tensor_list, bboxes=expanded_overall_bboxes, horizontal_shift_only=horizontal_shift_only)
107
+ else:
108
+ offset_list = [(0., 0.) for _ in range(len(latents_all_list))]
109
+ composed_latents, foreground_indices = compose_latents(model_dict, latents_all_list, mask_tensor_list, num_inference_steps, overall_batch_size, height, width, **kwargs)
110
+ return composed_latents, foreground_indices, offset_list
111
+
112
+ def get_input_latents_list(model_dict, bg_seed, fg_seed_start, fg_blending_ratio, height, width, so_prompt_phrase_box_list=None, so_boxes=None, verbose=False):
113
+ """
114
+ Note: the returned input latents are scaled by `scheduler.init_noise_sigma`
115
+ """
116
+ unet, scheduler, dtype = model_dict.unet, model_dict.scheduler, model_dict.dtype
117
+
118
+ generator_bg = torch.manual_seed(bg_seed) # Seed generator to create the inital latent noise
119
+ latents_bg = get_unscaled_latents(batch_size=1, in_channels=unet.config.in_channels, height=height, width=width, generator=generator_bg, dtype=dtype)
120
+
121
+ input_latents_list = []
122
+
123
+ if so_boxes is None:
124
+ # For compatibility
125
+ so_boxes = [item[-1] for item in so_prompt_phrase_box_list]
126
+
127
+ # change this changes the foreground initial noise
128
+ for idx, obj_box in enumerate(so_boxes):
129
+ H, W = height // 8, width // 8
130
+ fg_mask = utils.proportion_to_mask(obj_box, H, W)
131
+
132
+ if verbose:
133
+ plt.imshow(fg_mask.cpu().numpy())
134
+ plt.show()
135
+
136
+ fg_seed = fg_seed_start + idx
137
+ if fg_seed == bg_seed:
138
+ # We should have different seeds for foreground and background
139
+ fg_seed += 12345
140
+
141
+ generator_fg = torch.manual_seed(fg_seed)
142
+ latents_fg = get_unscaled_latents(batch_size=1, in_channels=unet.config.in_channels, height=height, width=width, generator=generator_fg, dtype=dtype)
143
+
144
+ input_latents = blend_latents(latents_bg, latents_fg, fg_mask, fg_blending_ratio=fg_blending_ratio)
145
+
146
+ input_latents = input_latents * scheduler.init_noise_sigma
147
+
148
+ input_latents_list.append(input_latents)
149
+
150
+ return input_latents_list, latents_bg
151
+
utils/parse.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import os
3
+ import json
4
+ from matplotlib.patches import Polygon
5
+ from matplotlib.collections import PatchCollection
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import cv2
9
+ import inflect
10
+
11
+ p = inflect.engine()
12
+
13
+ img_dir = "imgs"
14
+ bg_prompt_text = "Background prompt: "
15
+ # h, w
16
+ box_scale = (512, 512)
17
+ size = box_scale
18
+ size_h, size_w = size
19
+ print(f"Using box scale: {box_scale}")
20
+
21
+ def parse_input(text=None, no_input=False):
22
+ if not text:
23
+ if no_input:
24
+ return
25
+
26
+ text = input("Enter the response: ")
27
+ if "Objects: " in text:
28
+ text = text.split("Objects: ")[1]
29
+
30
+ text_split = text.split(bg_prompt_text)
31
+ if len(text_split) == 2:
32
+ gen_boxes, bg_prompt = text_split
33
+ elif len(text_split) == 1:
34
+ if no_input:
35
+ return
36
+ gen_boxes = text
37
+ bg_prompt = ""
38
+ while not bg_prompt:
39
+ # Ignore the empty lines in the response
40
+ bg_prompt = input("Enter the background prompt: ").strip()
41
+ if bg_prompt_text in bg_prompt:
42
+ bg_prompt = bg_prompt.split(bg_prompt_text)[1]
43
+ else:
44
+ raise ValueError(f"text: {text}")
45
+ try:
46
+ gen_boxes = ast.literal_eval(gen_boxes)
47
+ except SyntaxError as e:
48
+ # Sometimes the response is in plain text
49
+ if "No objects" in gen_boxes:
50
+ gen_boxes = []
51
+ else:
52
+ raise e
53
+ bg_prompt = bg_prompt.strip()
54
+
55
+ return gen_boxes, bg_prompt
56
+
57
+ def filter_boxes(gen_boxes, scale_boxes=True, ignore_background=True, max_scale=3):
58
+ if len(gen_boxes) == 0:
59
+ return []
60
+
61
+ box_dict_format = False
62
+ gen_boxes_new = []
63
+ for gen_box in gen_boxes:
64
+ if isinstance(gen_box, dict):
65
+ name, [bbox_x, bbox_y, bbox_w, bbox_h] = gen_box['name'], gen_box['bounding_box']
66
+ box_dict_format = True
67
+ else:
68
+ name, [bbox_x, bbox_y, bbox_w, bbox_h] = gen_box
69
+ if bbox_w <= 0 or bbox_h <= 0:
70
+ # Empty boxes
71
+ continue
72
+ if ignore_background:
73
+ if (bbox_w >= size[1] and bbox_h >= size[0]) or bbox_x > size[1] or bbox_y > size[0]:
74
+ # Ignore the background boxes
75
+ continue
76
+ gen_boxes_new.append(gen_box)
77
+
78
+ gen_boxes = gen_boxes_new
79
+
80
+ if len(gen_boxes) == 0:
81
+ return []
82
+
83
+ filtered_gen_boxes = []
84
+ if box_dict_format:
85
+ # For compatibility
86
+ bbox_left_x_min = min([gen_box['bounding_box'][0] for gen_box in gen_boxes])
87
+ bbox_right_x_max = max([gen_box['bounding_box'][0] + gen_box['bounding_box'][2] for gen_box in gen_boxes])
88
+ bbox_top_y_min = min([gen_box['bounding_box'][1] for gen_box in gen_boxes])
89
+ bbox_bottom_y_max = max([gen_box['bounding_box'][1] + gen_box['bounding_box'][3] for gen_box in gen_boxes])
90
+ else:
91
+ bbox_left_x_min = min([gen_box[1][0] for gen_box in gen_boxes])
92
+ bbox_right_x_max = max([gen_box[1][0] + gen_box[1][2] for gen_box in gen_boxes])
93
+ bbox_top_y_min = min([gen_box[1][1] for gen_box in gen_boxes])
94
+ bbox_bottom_y_max = max([gen_box[1][1] + gen_box[1][3] for gen_box in gen_boxes])
95
+
96
+ # All boxes are empty
97
+ if (bbox_right_x_max - bbox_left_x_min) == 0:
98
+ return []
99
+
100
+ # Used if scale_boxes is True
101
+ shift = -bbox_left_x_min
102
+ scale = size_w / (bbox_right_x_max - bbox_left_x_min)
103
+
104
+ scale = min(scale, max_scale)
105
+
106
+ for gen_box in gen_boxes:
107
+ if box_dict_format:
108
+ name, [bbox_x, bbox_y, bbox_w, bbox_h] = gen_box['name'], gen_box['bounding_box']
109
+ else:
110
+ name, [bbox_x, bbox_y, bbox_w, bbox_h] = gen_box
111
+
112
+ if scale_boxes:
113
+ # Vertical: move the boxes if out of bound
114
+ # Horizontal: move and scale the boxes so it spans the horizontal line
115
+
116
+ bbox_x = (bbox_x + shift) * scale
117
+ bbox_y = bbox_y * scale
118
+ bbox_w, bbox_h = bbox_w * scale, bbox_h * scale
119
+ # TODO: verify this makes the y center not moving
120
+ bbox_y_offset = 0
121
+ if bbox_top_y_min * scale + bbox_y_offset < 0:
122
+ bbox_y_offset -= bbox_top_y_min * scale
123
+ if bbox_bottom_y_max * scale + bbox_y_offset >= size_h:
124
+ bbox_y_offset -= bbox_bottom_y_max * scale - size_h
125
+ bbox_y += bbox_y_offset
126
+
127
+ if bbox_y < 0:
128
+ bbox_y, bbox_h = 0, bbox_h - bbox_y
129
+
130
+ name = name.rstrip(".")
131
+ bounding_box = (int(np.round(bbox_x)), int(np.round(bbox_y)), int(np.round(bbox_w)), int(np.round(bbox_h)))
132
+ if box_dict_format:
133
+ gen_box = {
134
+ 'name': name,
135
+ 'bounding_box': bounding_box
136
+ }
137
+ else:
138
+ gen_box = (name, bounding_box)
139
+
140
+ filtered_gen_boxes.append(gen_box)
141
+
142
+ return filtered_gen_boxes
143
+
144
+ def draw_boxes(anns):
145
+ ax = plt.gca()
146
+ ax.set_autoscale_on(False)
147
+ polygons = []
148
+ color = []
149
+ for ann in anns:
150
+ c = (np.random.random((1, 3))*0.6+0.4)
151
+ [bbox_x, bbox_y, bbox_w, bbox_h] = ann['bbox']
152
+ poly = [[bbox_x, bbox_y], [bbox_x, bbox_y+bbox_h],
153
+ [bbox_x+bbox_w, bbox_y+bbox_h], [bbox_x+bbox_w, bbox_y]]
154
+ np_poly = np.array(poly).reshape((4, 2))
155
+ polygons.append(Polygon(np_poly))
156
+ color.append(c)
157
+
158
+ # print(ann)
159
+ name = ann['name'] if 'name' in ann else str(ann['category_id'])
160
+ ax.text(bbox_x, bbox_y, name, style='italic',
161
+ bbox={'facecolor': 'white', 'alpha': 0.7, 'pad': 5})
162
+
163
+ p = PatchCollection(polygons, facecolor='none',
164
+ edgecolors=color, linewidths=2)
165
+ ax.add_collection(p)
166
+
167
+
168
+ def show_boxes(gen_boxes, bg_prompt=None, ind=None, show=False):
169
+ if len(gen_boxes) == 0:
170
+ return
171
+
172
+ if isinstance(gen_boxes[0], dict):
173
+ anns = [{'name': gen_box['name'], 'bbox': gen_box['bounding_box']}
174
+ for gen_box in gen_boxes]
175
+ else:
176
+ anns = [{'name': gen_box[0], 'bbox': gen_box[1]} for gen_box in gen_boxes]
177
+
178
+ # White background (to allow line to show on the edge)
179
+ I = np.ones((size[0]+4, size[1]+4, 3), dtype=np.uint8) * 255
180
+
181
+ plt.imshow(I)
182
+ plt.axis('off')
183
+
184
+ if bg_prompt is not None:
185
+ ax = plt.gca()
186
+ ax.text(0, 0, bg_prompt, style='italic',
187
+ bbox={'facecolor': 'white', 'alpha': 0.7, 'pad': 5})
188
+
189
+ c = (np.zeros((1, 3)))
190
+ [bbox_x, bbox_y, bbox_w, bbox_h] = (0, 0, size[1], size[0])
191
+ poly = [[bbox_x, bbox_y], [bbox_x, bbox_y+bbox_h],
192
+ [bbox_x+bbox_w, bbox_y+bbox_h], [bbox_x+bbox_w, bbox_y]]
193
+ np_poly = np.array(poly).reshape((4, 2))
194
+ polygons = [Polygon(np_poly)]
195
+ color = [c]
196
+ p = PatchCollection(polygons, facecolor='none',
197
+ edgecolors=color, linewidths=2)
198
+ ax.add_collection(p)
199
+
200
+ draw_boxes(anns)
201
+ if show:
202
+ plt.show()
203
+ else:
204
+ print("Saved to", f"{img_dir}/boxes.png", f"ind: {ind}")
205
+ if ind is not None:
206
+ plt.savefig(f"{img_dir}/boxes_{ind}.png")
207
+ plt.savefig(f"{img_dir}/boxes.png")
208
+
209
+
210
+ def show_masks(masks):
211
+ masks_to_show = np.zeros((*size, 3), dtype=np.float32)
212
+ for mask in masks:
213
+ c = (np.random.random((3,))*0.6+0.4)
214
+
215
+ masks_to_show += mask[..., None] * c[None, None, :]
216
+ plt.imshow(masks_to_show)
217
+ plt.savefig(f"{img_dir}/masks.png")
218
+ plt.show()
219
+ plt.clf()
220
+
221
+ def convert_box(box, height, width):
222
+ # box: x, y, w, h (in 512 format) -> x_min, y_min, x_max, y_max
223
+ x_min, y_min = box[0] / width, box[1] / height
224
+ w_box, h_box = box[2] / width, box[3] / height
225
+
226
+ x_max, y_max = x_min + w_box, y_min + h_box
227
+
228
+ return x_min, y_min, x_max, y_max
229
+
230
+ def convert_spec(spec, height, width, include_counts=True, verbose=False):
231
+ # Infer from spec
232
+ prompt, gen_boxes, bg_prompt = spec['prompt'], spec['gen_boxes'], spec['bg_prompt']
233
+
234
+ # This ensures the same objects appear together because flattened `overall_phrases_bboxes` should EXACTLY correspond to `so_prompt_phrase_box_list`.
235
+ gen_boxes = sorted(gen_boxes, key=lambda gen_box: gen_box[0])
236
+
237
+ gen_boxes = [(name, convert_box(box, height=height, width=width)) for name, box in gen_boxes]
238
+
239
+ # NOTE: so phrase should include all the words associated to the object (otherwise "an orange dog" may be recognized as "an orange" by the model generating the background).
240
+ # so word should have one token that includes the word to transfer cross attention (the object name).
241
+ # Currently using the last word of the object name as word.
242
+ if bg_prompt:
243
+ so_prompt_phrase_word_box_list = [(f"{bg_prompt} with {name}", name, name.split(" ")[-1], box) for name, box in gen_boxes]
244
+ else:
245
+ so_prompt_phrase_word_box_list = [(f"{name}", name, name.split(" ")[-1], box) for name, box in gen_boxes]
246
+
247
+ objects = [gen_box[0] for gen_box in gen_boxes]
248
+
249
+ objects_unique, objects_count = np.unique(objects, return_counts=True)
250
+
251
+ num_total_matched_boxes = 0
252
+ overall_phrases_words_bboxes = []
253
+ for ind, object_name in enumerate(objects_unique):
254
+ bboxes = [box for name, box in gen_boxes if name == object_name]
255
+
256
+ if objects_count[ind] > 1:
257
+ phrase = p.plural_noun(object_name.replace("an ", "").replace("a ", ""))
258
+ if include_counts:
259
+ phrase = p.number_to_words(objects_count[ind]) + " " + phrase
260
+ else:
261
+ phrase = object_name
262
+ # Currently using the last word of the phrase as word.
263
+ word = phrase.split(' ')[-1]
264
+
265
+ num_total_matched_boxes += len(bboxes)
266
+ overall_phrases_words_bboxes.append((phrase, word, bboxes))
267
+
268
+ assert num_total_matched_boxes == len(gen_boxes), f"{num_total_matched_boxes} != {len(gen_boxes)}"
269
+
270
+ objects_str = ", ".join([phrase for phrase, _, _ in overall_phrases_words_bboxes])
271
+ if objects_str:
272
+ if bg_prompt:
273
+ overall_prompt = f"{bg_prompt} with {objects_str}"
274
+ else:
275
+ overall_prompt = objects_str
276
+ else:
277
+ overall_prompt = bg_prompt
278
+
279
+ if verbose:
280
+ print("so_prompt_phrase_word_box_list:", so_prompt_phrase_word_box_list)
281
+ print("overall_prompt:", overall_prompt)
282
+ print("overall_phrases_words_bboxes:", overall_phrases_words_bboxes)
283
+
284
+ return so_prompt_phrase_word_box_list, overall_prompt, overall_phrases_words_bboxes
utils/utils.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import ImageDraw
3
+ import numpy as np
4
+ import os
5
+ import gc
6
+
7
+ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
8
+
9
+ def draw_box(pil_img, bboxes, phrases):
10
+ draw = ImageDraw.Draw(pil_img)
11
+ # font = ImageFont.truetype('./FreeMono.ttf', 25)
12
+
13
+ for obj_bbox, phrase in zip(bboxes, phrases):
14
+ x_0, y_0, x_1, y_1 = obj_bbox[0], obj_bbox[1], obj_bbox[2], obj_bbox[3]
15
+ draw.rectangle([int(x_0 * 512), int(y_0 * 512), int(x_1 * 512), int(y_1 * 512)], outline='red', width=5)
16
+ draw.text((int(x_0 * 512) + 5, int(y_0 * 512) + 5), phrase, font=None, fill=(255, 0, 0))
17
+
18
+ return pil_img
19
+
20
+ def get_centered_box(box, horizontal_center_only=True):
21
+ x_min, y_min, x_max, y_max = box
22
+ w = x_max - x_min
23
+
24
+ if horizontal_center_only:
25
+ return [0.5 - w/2, y_min, 0.5 + w/2, y_max]
26
+
27
+ h = y_max - y_min
28
+
29
+ return [0.5 - w/2, 0.5 - h/2, 0.5 + w/2, 0.5 + h/2]
30
+
31
+ # NOTE: this changes the behavior of the function
32
+ def proportion_to_mask(obj_box, H, W, use_legacy=False, return_np=False):
33
+ x_min, y_min, x_max, y_max = scale_proportion(obj_box, H, W, use_legacy)
34
+ if return_np:
35
+ mask = np.zeros((H, W))
36
+ else:
37
+ mask = torch.zeros(H, W).to(torch_device)
38
+ mask[y_min: y_max, x_min: x_max] = 1.
39
+
40
+ return mask
41
+
42
+ def scale_proportion(obj_box, H, W, use_legacy=False):
43
+ if use_legacy:
44
+ # Bias towards the top-left corner
45
+ x_min, y_min, x_max, y_max = int(obj_box[0] * W), int(obj_box[1] * H), int(obj_box[2] * W), int(obj_box[3] * H)
46
+ else:
47
+ # Separately rounding box_w and box_h to allow shift invariant box sizes. Otherwise box sizes may change when both coordinates being rounded end with ".5".
48
+ x_min, y_min = round(obj_box[0] * W), round(obj_box[1] * H)
49
+ box_w, box_h = round((obj_box[2] - obj_box[0]) * W), round((obj_box[3] - obj_box[1]) * H)
50
+ x_max, y_max = x_min + box_w, y_min + box_h
51
+
52
+ x_min, y_min = max(x_min, 0), max(y_min, 0)
53
+ x_max, y_max = min(x_max, W), min(y_max, H)
54
+
55
+ return x_min, y_min, x_max, y_max
56
+
57
+ def binary_mask_to_box(mask, enlarge_box_by_one=True, w_scale=1, h_scale=1):
58
+ if isinstance(mask, torch.Tensor):
59
+ mask_loc = torch.where(mask)
60
+ else:
61
+ mask_loc = np.where(mask)
62
+ height, width = mask.shape
63
+ if len(mask_loc) == 0:
64
+ raise ValueError('The mask is empty')
65
+ if enlarge_box_by_one:
66
+ ymin, ymax = max(min(mask_loc[0]) - 1, 0), min(max(mask_loc[0]) + 1, height)
67
+ xmin, xmax = max(min(mask_loc[1]) - 1, 0), min(max(mask_loc[1]) + 1, width)
68
+ else:
69
+ ymin, ymax = min(mask_loc[0]), max(mask_loc[0])
70
+ xmin, xmax = min(mask_loc[1]), max(mask_loc[1])
71
+ box = [xmin * w_scale, ymin * h_scale, xmax * w_scale, ymax * h_scale]
72
+
73
+ return box
74
+
75
+ def binary_mask_to_box_mask(mask, to_device=True):
76
+ box = binary_mask_to_box(mask)
77
+ x_min, y_min, x_max, y_max = box
78
+
79
+ H, W = mask.shape
80
+ mask = torch.zeros(H, W)
81
+ if to_device:
82
+ mask = mask.to(torch_device)
83
+ mask[y_min: y_max+1, x_min: x_max+1] = 1.
84
+
85
+ return mask
86
+
87
+ def binary_mask_to_center(mask, normalize=False):
88
+ """
89
+ This computes the mass center of the mask.
90
+ normalize: the coords range from 0 to 1
91
+
92
+ Reference: https://stackoverflow.com/a/66184125
93
+ """
94
+ h, w = mask.shape
95
+
96
+ total = mask.sum()
97
+ if isinstance(mask, torch.Tensor):
98
+ x_coord = ((mask.sum(dim=0) @ torch.arange(w)) / total).item()
99
+ y_coord = ((mask.sum(dim=1) @ torch.arange(h)) / total).item()
100
+ else:
101
+ x_coord = (mask.sum(axis=0) @ np.arange(w)) / total
102
+ y_coord = (mask.sum(axis=1) @ np.arange(h)) / total
103
+
104
+ if normalize:
105
+ x_coord, y_coord = x_coord / w, y_coord / h
106
+ return x_coord, y_coord
107
+
108
+
109
+ def iou(mask, masks, eps=1e-6):
110
+ # mask: [h, w], masks: [n, h, w]
111
+ mask = mask[None].astype(bool)
112
+ masks = masks.astype(bool)
113
+ i = (mask & masks).sum(axis=(1,2))
114
+ u = (mask | masks).sum(axis=(1,2))
115
+
116
+ return i / (u + eps)
117
+
118
+ def free_memory():
119
+ gc.collect()
120
+ torch.cuda.empty_cache()
121
+
122
+ def expand_overall_bboxes(overall_bboxes):
123
+ """
124
+ Expand overall bboxes from a 3d list to 2d list:
125
+ Input: [[box 1 for phrase 1, box 2 for phrase 1], ...]
126
+ Output: [box 1, box 2, ...]
127
+ """
128
+ return sum(overall_bboxes, start=[])
129
+
130
+ def shift_tensor(tensor, x_offset, y_offset, base_w=8, base_h=8, offset_normalized=False, ignore_last_dim=False):
131
+ """base_w and base_h: make sure the shift is aligned in the latent and multiple levels of cross attention"""
132
+ if ignore_last_dim:
133
+ tensor_h, tensor_w = tensor.shape[-3:-1]
134
+ else:
135
+ tensor_h, tensor_w = tensor.shape[-2:]
136
+ if offset_normalized:
137
+ assert tensor_h % base_h == 0 and tensor_w % base_w == 0, f"{tensor_h, tensor_w} is not a multiple of {base_h, base_w}"
138
+ scale_from_base_h, scale_from_base_w = tensor_h // base_h, tensor_w // base_w
139
+ x_offset, y_offset = round(x_offset * base_w) * scale_from_base_w, round(y_offset * base_h) * scale_from_base_h
140
+ new_tensor = torch.zeros_like(tensor)
141
+
142
+ overlap_w = tensor_w - abs(x_offset)
143
+ overlap_h = tensor_h - abs(y_offset)
144
+
145
+ if y_offset >= 0:
146
+ y_src_start = 0
147
+ y_dest_start = y_offset
148
+ else:
149
+ y_src_start = -y_offset
150
+ y_dest_start = 0
151
+
152
+ if x_offset >= 0:
153
+ x_src_start = 0
154
+ x_dest_start = x_offset
155
+ else:
156
+ x_src_start = -x_offset
157
+ x_dest_start = 0
158
+
159
+ if ignore_last_dim:
160
+ # For cross attention maps, the third to last and the second to last are the 2D dimensions after unflatten.
161
+ new_tensor[..., y_dest_start:y_dest_start+overlap_h, x_dest_start:x_dest_start+overlap_w, :] = tensor[..., y_src_start:y_src_start+overlap_h, x_src_start:x_src_start+overlap_w, :]
162
+ else:
163
+ new_tensor[..., y_dest_start:y_dest_start+overlap_h, x_dest_start:x_dest_start+overlap_w] = tensor[..., y_src_start:y_src_start+overlap_h, x_src_start:x_src_start+overlap_w]
164
+
165
+ return new_tensor