Songwei Ge committed on
Commit
4c022fe
·
1 Parent(s): 14a857e
app.py CHANGED
@@ -1,10 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
- HTML = "<!-- Include stylesheet -->\n<link href=\"https://cdn.quilljs.com/1.3.6/quill.snow.css\" rel=\"stylesheet\">\n\n<!-- Create the editor container -->\n<div id=\"editor\">\n <p>Hello World!</p>\n <p>Some initial <strong>bold</strong> text</p>\n <p><br></p>\n</div>\n\n<!-- Include the Quill library -->\n<script src=\"https://cdn.quilljs.com/1.3.6/quill.js\"></script>\n\n<!-- Initialize Quill editor -->\n<script>\n var quill = new Quill('#editor', {\n theme: 'snow'\n });\n</script>"
 
5
 
6
- def greet(name):
7
- return HTML, "Hello " + name + "!!"
8
 
9
- iface = gr.Interface(greet, gr.Textbox(placeholder="Enter sentence here..."), ["html", "text"])
10
- iface.launch()
 
1
+ import math
2
+ import random
3
+ import os
4
+ import json
5
+ import time
6
+ import argparse
7
+ import imageio
8
+ import torch
9
+ import numpy as np
10
+ from torchvision import transforms
11
+
12
+ from models.region_diffusion import RegionDiffusion
13
+ from utils.attention_utils import get_token_maps
14
+ from utils.richtext_utils import seed_everything, parse_json, get_region_diffusion_input,\
15
+ get_attention_control_input, get_gradient_guidance_input
16
+
17
+
18
  import gradio as gr
19
+ from PIL import Image, ImageOps
20
+
21
+
22
# Placeholder help text for the demo UI.
help_text = "\nInstructions placeholder.\n"


# Sample free-form edit instructions, grouped loosely by theme.
example_instructions = [
    # Artistic restyling
    "Make it a picasso painting",
    "as if it were by modigliani",
    "convert to a bronze statue",
    "Turn it into an anime.",
    "have it look like a graphic novel",
    # Changes to the subject
    "make him gain weight",
    "what would he look like bald?",
    "Have him smile",
    # Changes to the scene
    "Put him in a cocktail party.",
    "move him at the beach.",
    "add dramatic lighting",
    "Convert to black and white",
    "What if it were snowing?",
    # Clothing and accessories
    "Give him a leather jacket",
    "Turn him into a cyborg!",
    "make him wear a beanie",
]
45
+
46
def main():
    """Build the rich-text-to-image Gradio demo and launch it.

    Instantiates one ``RegionDiffusion`` model (on CUDA when available,
    otherwise on CPU), defines the ``generate`` callback that closes over that
    model, lays out the Gradio Blocks UI, then queues and launches the app.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Load the region diffusion model once; every request reuses it.
    model = RegionDiffusion(device)

    def generate(
        text_input: str,
        negative_text: str,
        height: int,
        width: int,
        seed: int,
        steps: int = 41,
        guidance_weight: float = 8.5,
    ):
        """Generate a plain-text image and a rich-text image for one request.

        Args:
            text_input: Rich-text JSON string; handed to ``parse_json``.
            negative_text: Negative prompt, passed to the model as a
                single-element list.
            height: Output image height in pixels.
            width: Output image width in pixels.
            seed: Seed passed to ``seed_everything`` before each pass.
            steps: Number of inference steps. Falsy values (``0``/``None``)
                fall back to 41. Defaults to 41 so the callback can be invoked
                with only the five inputs wired to ``gr.Examples``.
            guidance_weight: Classifier-free guidance scale. Falsy values fall
                back to 8.5. Defaults to 8.5 for the same reason as ``steps``.

        Returns:
            A two-element list: the first image of the plain-text pass and the
            first image of the rich-text pass.
        """
        run_dir = 'results/'
        steps = steps or 41
        guidance_weight = guidance_weight or 8.5

        # parse json to span attributes
        (base_text_prompt, style_text_prompts, footnote_text_prompts,
         footnote_target_tokens, color_text_prompts, color_names, color_rgbs,
         size_text_prompts_and_sizes, use_grad_guidance) = parse_json(text_input)

        # create control input for region diffusion
        region_text_prompts, region_target_token_ids, base_tokens = get_region_diffusion_input(
            model, base_text_prompt, style_text_prompts, footnote_text_prompts,
            footnote_target_tokens, color_text_prompts, color_names)

        # create control input for cross attention
        text_format_dict = get_attention_control_input(
            model, base_tokens, size_text_prompts_and_sizes)

        # create control input for region guidance
        text_format_dict, color_target_token_ids = get_gradient_guidance_input(
            model, base_tokens, color_text_prompts, color_rgbs, text_format_dict)

        seed_everything(seed)

        # get token maps from plain text to image generation.
        begin_time = time.time()
        if model.attention_maps is None:
            model.register_evaluation_hooks()
        else:
            model.reset_attention_maps()
        plain_img = model.produce_attn_maps([base_text_prompt], [negative_text],
                                            height=height, width=width,
                                            num_inference_steps=steps,
                                            guidance_scale=guidance_weight)
        print('time lapses to get attention maps: %.4f' % (time.time()-begin_time))

        # Token maps are requested at 1/8 of the image resolution.
        color_obj_masks = get_token_maps(
            model.attention_maps, run_dir, width//8, height//8, color_target_token_ids, seed)
        model.masks = get_token_maps(
            model.attention_maps, run_dir, width//8, height//8, region_target_token_ids, seed,
            base_tokens)
        # Bring the colour masks back up to full image resolution.
        color_obj_masks = [
            transforms.functional.resize(
                color_obj_mask, (height, width),
                interpolation=transforms.InterpolationMode.BICUBIC,
                antialias=True)
            for color_obj_mask in color_obj_masks
        ]
        text_format_dict['color_obj_atten'] = color_obj_masks
        model.remove_evaluation_hooks()

        # generate image from rich text
        begin_time = time.time()
        # Re-seed so the rich-text pass starts from the same random state as the plain pass.
        seed_everything(seed)
        rich_img = model.prompt_to_img(region_text_prompts, [negative_text],
                                       height=height, width=width,
                                       num_inference_steps=steps,
                                       guidance_scale=guidance_weight,
                                       use_grad_guidance=use_grad_guidance,
                                       text_format_dict=text_format_dict)
        print('time lapses to generate image from rich text: %.4f' %
              (time.time()-begin_time))
        return [plain_img[0], rich_img[0]]

    with gr.Blocks() as demo:
        gr.HTML("""<h1 style="font-weight: 900; margin-bottom: 7px;">Expressive Text-to-Image Generation with Rich Text</h1>
                   <p> Visit our <a href="https://rich-text-to-image.github.io/rich-text-to-json.html">rich-text-to-json interface</a> to generate rich-text JSON input.</p>""")
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(
                    label='Rich-text JSON Input',
                    max_lines=1,
                    placeholder='Example: \'{"ops":[{"insert":"a Gothic "},{"attributes":{"color":"#b26b00"},"insert":"church"},{"insert":" in a the sunset with a beautiful landscape in the background.\n"}]}\'')
                negative_prompt = gr.Textbox(
                    label='Negative Prompt',
                    max_lines=1,
                    placeholder='')
                seed = gr.Slider(label='Seed',
                                 minimum=0,
                                 maximum=100000,
                                 step=1,
                                 value=6)
                with gr.Accordion('Other Parameters', open=False):
                    steps = gr.Slider(label='Number of Steps',
                                      minimum=0,
                                      maximum=500,
                                      step=1,
                                      value=41)
                    guidance_weight = gr.Slider(label='CFG weight',
                                                minimum=0,
                                                maximum=50,
                                                step=0.1,
                                                value=8.5)
                    width = gr.Dropdown(choices=[512, 768, 896],
                                        value=512,
                                        label='Width',
                                        visible=True)
                    height = gr.Dropdown(choices=[512, 768, 896],
                                         value=512,
                                         label='height',
                                         visible=True)

                with gr.Row():
                    with gr.Column(scale=1, min_width=100):
                        generate_button = gr.Button("Generate")

            with gr.Column():
                # NOTE(review): `generate` returns [plain image, rich-text image], so
                # 'Result' shows the plain-text image and 'TokenMap' shows the
                # rich-text image - confirm the labels are intended.
                result = gr.Image(label='Result')
                token_map = gr.Image(label='TokenMap')

        with gr.Row():
            # Each example row is: rich-text JSON, negative prompt, height, width, seed.
            examples = [
                [
                    '{"ops":[{"insert":"a Gothic "},{"attributes":{"color":"#b26b00"},"insert":"church"},{"insert":" in a the sunset with a beautiful landscape in the background.\n"}]}',
                    '',
                    512,
                    512,
                    6,
                ],
                [
                    '{"ops": [{"insert": "A pizza with "}, {"attributes": {"size": "50px"}, "insert": "pineapples"}, {"insert": ", pepperonis, and mushrooms on the top, 4k, photorealistic\n"}]}',
                    'blurry, art, painting, rendering, drawing, sketch, ugly, duplicate, morbid, mutilated, mutated, deformed, disfigured low quality, worst quality',
                    768,
                    896,
                    6,
                ],
                [
                    '{"ops":[{"insert":"a "},{"attributes":{"font":"mirza"},"insert":"beautiful garden"},{"insert":" with a "},{"attributes":{"font":"roboto"},"insert":"snow mountain in the background"},{"insert":"\n"}]}',
                    '',
                    512,
                    512,
                    3,
                ],
                [
                    '{"ops":[{"insert":"A close-up 4k dslr photo of a "},{"attributes":{"link":"A cat wearing sunglasses and a bandana around its neck."},"insert":"cat"},{"insert":" riding a scooter. Palm trees in the background.\n"}]}',
                    '',
                    512,
                    512,
                    6,
                ],
            ]
            # Only five inputs are wired here; `generate` supplies defaults for
            # `steps` and `guidance_weight` when it is called for an example.
            gr.Examples(examples=examples,
                        inputs=[
                            text_input,
                            negative_prompt,
                            height,
                            width,
                            seed,
                        ],
                        outputs=[
                            result,
                            token_map,
                        ],
                        fn=generate,
                        # cache_examples=True,
                        examples_per_page=20)

        generate_button.click(
            fn=generate,
            inputs=[
                text_input,
                negative_prompt,
                height,
                width,
                seed,
                steps,
                guidance_weight,
            ],
            outputs=[result, token_map],
        )

    demo.queue(concurrency_count=1)
    demo.launch(share=False)
225
 
 
 
226
 
227
+ if __name__ == "__main__":
228
+ main()
models/attention.py ADDED
@@ -0,0 +1,892 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ import warnings
16
+ from dataclasses import dataclass
17
+ from typing import Optional
18
+
19
+ import torch
20
+ import torch.nn.functional as F
21
+ from torch import nn
22
+
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.models.modeling_utils import ModelMixin
25
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
26
+ from diffusers.utils import BaseOutput
27
+ from diffusers.utils.import_utils import is_xformers_available
28
+
29
+
30
@dataclass
class Transformer2DModelOutput(BaseOutput):
    """
    Output container returned by [`Transformer2DModel`].

    Args:
        sample (`torch.FloatTensor`):
            Hidden states conditioned on the `encoder_hidden_states` input. For continuous input the shape is
            `(batch_size, num_channels, height, width)`. For discrete input the shape is
            `(batch size, num_vector_embeds - 1, num_latent_pixels)` and the values are log-probabilities for the
            unnoised latent pixels.
    """

    # The only payload of this output object; see the class docstring for its shape.
    sample: torch.FloatTensor
40
+
41
+
42
+ if is_xformers_available():
43
+ import xformers
44
+ import xformers.ops
45
+ else:
46
+ xformers = None
47
+
48
+
49
class Transformer2DModel(ModelMixin, ConfigMixin):
    """
    Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
    embeddings) inputs.

    When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
    transformer action. Finally, reshape to image.

    When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
    embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
    classes of unnoised image.

    Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
    image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            Pass if the input is continuous. The number of channels in the input and output.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        norm_num_groups (`int`, *optional*, defaults to 32):
            Number of groups of the input `GroupNorm` (continuous input only).
        cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
        attention_bias (`bool`, *optional*):
            Configure if the TransformerBlocks' attention should contain a bias parameter.
        sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
            Note that this is fixed at training time as it is used for learning a number of position embeddings. See
            `ImagePositionalEmbeddings`.
        num_vector_embeds (`int`, *optional*):
            Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
            Includes the class for the masked latent pixel.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
            The number of diffusion steps used during training. Note that this is fixed at training time as it is used
            to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
            up to but not more than steps than `num_embeds_ada_norm`.
        use_linear_projection (`bool`, *optional*, defaults to `False`):
            For continuous input, use `nn.Linear` instead of a 1x1 `nn.Conv2d` for the input/output projections.
        only_cross_attention (`bool`, *optional*, defaults to `False`):
            Forwarded to every `BasicTransformerBlock`.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
    ):
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        # 1. Transformer2DModel can process both standard continuous images of shape
        # `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape
        # `(batch_size, num_image_vectors)`.
        # Define whether input is continuous or discrete depending on configuration
        self.is_input_continuous = in_channels is not None
        self.is_input_vectorized = num_vector_embeds is not None

        if self.is_input_continuous and self.is_input_vectorized:
            raise ValueError(
                f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
                " sure that either `in_channels` or `num_vector_embeds` is None."
            )
        elif not self.is_input_continuous and not self.is_input_vectorized:
            raise ValueError(
                f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make"
                " sure that either `in_channels` or `num_vector_embeds` is not None."
            )

        # 2. Define input layers
        if self.is_input_continuous:
            self.in_channels = in_channels

            self.norm = torch.nn.GroupNorm(
                num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
            if use_linear_projection:
                self.proj_in = nn.Linear(in_channels, inner_dim)
            else:
                self.proj_in = nn.Conv2d(
                    in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
            assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"

            self.height = sample_size
            self.width = sample_size
            self.num_vector_embeds = num_vector_embeds
            self.num_latent_pixels = self.height * self.width

            self.latent_image_embedding = ImagePositionalEmbeddings(
                num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
            )

        # 3. Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Define output layers
        if self.is_input_continuous:
            if use_linear_projection:
                # Map the transformer width back to the input channel count. The weight has the
                # same shape as before whenever `in_channels == inner_dim`, which was the only
                # configuration the previous (reversed) layer could run with.
                self.proj_out = nn.Linear(inner_dim, in_channels)
            else:
                self.proj_out = nn.Conv2d(
                    inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            self.norm_out = nn.LayerNorm(inner_dim)
            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)

    def _set_attention_slice(self, slice_size):
        """Propagate the attention slice size to every transformer block."""
        for block in self.transformer_blocks:
            block._set_attention_slice(slice_size)

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None,
                text_format_dict=None, return_dict: bool = True):
        """
        Args:
            hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
                When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
                hidden_states
            encoder_hidden_states (*optional*):
                Conditional embeddings for the cross-attention layers; forwarded to every block as `context`.
            timestep ( `torch.long`, *optional*):
                Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
            text_format_dict (`dict`, *optional*):
                Forwarded unchanged to every transformer block. `None` (the default) is replaced by a fresh
                empty dict for each call.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`Transformer2DModelOutput`] instead of a plain tuple.

        Returns:
            [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
            if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
            tensor.
        """
        # Avoid sharing one mutable default dict between calls.
        if text_format_dict is None:
            text_format_dict = {}

        # 1. Input
        if self.is_input_continuous:
            batch, channel, height, width = hidden_states.shape
            residual = hidden_states

            hidden_states = self.norm(hidden_states)
            if not self.use_linear_projection:
                hidden_states = self.proj_in(hidden_states)
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(
                    0, 2, 3, 1).reshape(batch, height * width, inner_dim)
            else:
                # Flatten the spatial dims first, then project the channel axis.
                hidden_states = hidden_states.permute(
                    0, 2, 3, 1).reshape(batch, height * width, channel)
                hidden_states = self.proj_in(hidden_states)
        elif self.is_input_vectorized:
            hidden_states = self.latent_image_embedding(hidden_states)

        # 2. Blocks
        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep,
                                  text_format_dict=text_format_dict)

        # 3. Output
        if self.is_input_continuous:
            if not self.use_linear_projection:
                hidden_states = (
                    hidden_states.reshape(batch, height, width, inner_dim).permute(
                        0, 3, 1, 2).contiguous()
                )
                hidden_states = self.proj_out(hidden_states)
            else:
                # After `proj_out` the last axis holds `channel` features again.
                hidden_states = self.proj_out(hidden_states)
                hidden_states = (
                    hidden_states.reshape(batch, height, width, channel).permute(
                        0, 3, 1, 2).contiguous()
                )

            output = hidden_states + residual
        elif self.is_input_vectorized:
            hidden_states = self.norm_out(hidden_states)
            logits = self.out(hidden_states)
            # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
            logits = logits.permute(0, 2, 1)

            # log(p(x_0))
            output = F.log_softmax(logits.double(), dim=1).float()

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)

    def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        """Propagate the xformers memory-efficient-attention switch to every transformer block."""
        for block in self.transformer_blocks:
            block._set_use_memory_efficient_attention_xformers(
                use_memory_efficient_attention_xformers)
262
+
263
+
264
class AttentionBlock(nn.Module):
    """
    Self-attention over the spatial positions of a feature map. Originally ported from
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66
    and adapted to the N-d case. Query, key and value each come from their own linear layer.

    Parameters:
        channels (`int`): The number of channels in the input and output.
        num_head_channels (`int`, *optional*):
            The number of channels in each head. If None, then `num_heads` = 1.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
        rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
        eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
    """

    def __init__(
        self,
        channels: int,
        num_head_channels: Optional[int] = None,
        norm_num_groups: int = 32,
        rescale_output_factor: float = 1.0,
        eps: float = 1e-5,
    ):
        super().__init__()
        self.channels = channels

        if num_head_channels is None:
            self.num_heads = 1
        else:
            self.num_heads = channels // num_head_channels
        self.num_head_size = num_head_channels
        self.group_norm = nn.GroupNorm(
            norm_num_groups, channels, eps=eps, affine=True)

        # q, k and v are separate linear projections of the same width
        self.query = nn.Linear(channels, channels)
        self.key = nn.Linear(channels, channels)
        self.value = nn.Linear(channels, channels)

        self.rescale_output_factor = rescale_output_factor
        # The third argument of nn.Linear is `bias`; spelled out here for readability.
        self.proj_attn = nn.Linear(channels, channels, bias=True)

    def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
        """Split the last axis into heads: (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)."""
        per_head = projection.view(*projection.shape[:-1], self.num_heads, -1)
        return per_head.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        """Apply group norm, spatial self-attention and a rescaled residual connection.

        Args:
            hidden_states: 4-D tensor unpacked as `(batch, channel, height, width)`.

        Returns:
            Tensor of the same shape: `(attention_output + input) / rescale_output_factor`.
        """
        skip = hidden_states
        batch, channel, height, width = hidden_states.shape

        # normalise, then flatten the spatial grid into a token axis: (B, C, H*W) -> (B, H*W, C)
        tokens = self.group_norm(hidden_states)
        tokens = tokens.view(batch, channel, height * width).transpose(1, 2)

        q = self.query(tokens)
        k = self.key(tokens)
        v = self.value(tokens)

        scale = 1 / math.sqrt(self.channels / self.num_heads)
        multi_head = self.num_heads > 1

        # attention scores
        if multi_head:
            q, k, v = (self.transpose_for_scores(t) for t in (q, k, v))
            # TODO: is there a way to perform batched matmul (e.g. baddbmm) on 4D tensors?
            #       or reformulate this into a 3D problem?
            # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
            #       as some matmuls can be 1.94x slower than an equivalent einsum on MPS
            #       https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
            scores = torch.matmul(q, k.transpose(-1, -2)) * scale
        else:
            # beta=0 means the uninitialised buffer is ignored; only alpha * (q @ k^T) remains.
            buffer = torch.empty(
                q.shape[0],
                q.shape[1],
                k.shape[1],
                dtype=q.dtype,
                device=q.device,
            )
            scores = torch.baddbmm(
                buffer, q, k.transpose(-1, -2), beta=0, alpha=scale)

        # softmax in float32, then cast back to the score dtype
        probs = torch.softmax(scores.float(), dim=-1).type(scores.dtype)

        # weighted sum of the values
        if multi_head:
            merged = torch.matmul(probs, v).permute(0, 2, 1, 3).contiguous()
            # fold the heads back into the channel axis: (B, T, H, D) -> (B, T, C)
            attended = merged.view(*merged.shape[:-2], self.channels)
        else:
            attended = torch.bmm(probs, v)

        # output projection and restore the (B, C, H, W) layout
        attended = self.proj_attn(attended)
        attended = attended.transpose(-1, -2).reshape(batch, channel, height, width)

        # residual connection followed by rescaling
        return (attended + skip) / self.rescale_output_factor
384
+
385
+
386
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block: attention (`attn1`), cross-attention (`attn2`) and a feed-forward layer, each
    preceded by a normalisation and followed by a residual connection.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the context vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (`int`, *optional*):
            The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*, defaults to `False`):
            If `True`, `attn1` is built with `cross_attention_dim` and is called with the context instead of
            acting as pure self-attention.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention
        self.attn1 = CrossAttention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
        )  # is a self-attention
        self.ff = FeedForward(dim, dropout=dropout,
                              activation_fn=activation_fn)
        self.attn2 = CrossAttention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
        )  # is self-attn if context is none

        # layer norms
        self.use_ada_layer_norm = num_embeds_ada_norm is not None
        if self.use_ada_layer_norm:
            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
            self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
        else:
            self.norm1 = nn.LayerNorm(dim)
            self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)

        # if xformers is installed try to use memory_efficient_attention by default
        if is_xformers_available():
            try:
                self._set_use_memory_efficient_attention_xformers(True)
            except Exception as e:
                # Best effort only: fall back to regular attention and tell the user why.
                warnings.warn(
                    "Could not enable memory efficient attention. Make sure xformers is installed"
                    f" correctly and a GPU is available: {e}"
                )

    def _set_attention_slice(self, slice_size):
        """Set the attention slice size on both attention layers."""
        self.attn1._slice_size = slice_size
        self.attn2._slice_size = slice_size

    def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        """Switch xformers memory-efficient attention on or off for both attention layers.

        Raises:
            ModuleNotFoundError: If xformers is not installed.
            ValueError: If CUDA is not available.
            Exception: Whatever the xformers smoke test raises when the kernel cannot run.
        """
        if not is_xformers_available():
            print("Here is how to install it")
            raise ModuleNotFoundError(
                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                " xformers",
                name="xformers",
            )
        elif not torch.cuda.is_available():
            raise ValueError(
                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
                " available for GPU "
            )
        else:
            # Make sure we can run the memory efficient attention; any failure propagates to the caller.
            _ = xformers.ops.memory_efficient_attention(
                torch.randn((1, 2, 40), device="cuda"),
                torch.randn((1, 2, 40), device="cuda"),
                torch.randn((1, 2, 40), device="cuda"),
            )
            self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
            self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers

    def forward(self, hidden_states, context=None, timestep=None, text_format_dict=None):
        """Run attention, cross-attention and feed-forward, each with a residual connection.

        Args:
            hidden_states: Input tensor; also used as the residual for each sub-layer.
            context: Optional conditioning passed to `attn2` (and to `attn1` when
                `only_cross_attention` is set).
            timestep: Passed to the norm layers only when `AdaLayerNorm` is in use.
            text_format_dict: Forwarded to the attention layers that receive `context`. `None` (the default)
                is replaced by a fresh empty dict for each call.

        Returns:
            The transformed hidden states.
        """
        # Avoid sharing one mutable default dict between calls.
        if text_format_dict is None:
            text_format_dict = {}

        # 1. Self-Attention
        norm_hidden_states = (
            self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(
                hidden_states)
        )

        # The attention layers return a pair; only the first element is used here.
        if self.only_cross_attention:
            attn_out, _ = self.attn1(
                norm_hidden_states, context=context, text_format_dict=text_format_dict)
        else:
            attn_out, _ = self.attn1(norm_hidden_states)
        # Add the residual exactly once, after unpacking the attention output.
        hidden_states = attn_out + hidden_states

        # 2. Cross-Attention
        norm_hidden_states = (
            self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(
                hidden_states)
        )
        attn_out, _ = self.attn2(
            norm_hidden_states, context=context, text_format_dict=text_format_dict)
        hidden_states = attn_out + hidden_states

        # 3. Feed-forward
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

        return hidden_states
514
+
515
+
516
class CrossAttention(nn.Module):
    r"""
    A cross attention layer that also returns its attention probabilities.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the context. If not given, defaults to `query_dim`.
        heads (`int`,  *optional*, defaults to 8): The number of heads to use for multi-head attention.
        dim_head (`int`,  *optional*, defaults to 64): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        bias (`bool`, *optional*, defaults to False):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias=False,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        # Remember whether a separate context dimension was requested; the
        # font-size re-weighting in `_attention` only applies in that case.
        self.is_cross_attn = cross_attention_dim is not None
        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim

        self.scale = dim_head**-0.5
        self.heads = heads
        # `_slice_size` can be set through `set_attention_slice`, but this
        # implementation has no sliced code path: full attention is always used.
        self._slice_size = None
        self._use_memory_efficient_attention_xformers = False

        self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)

        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(inner_dim, query_dim))
        self.to_out.append(nn.Dropout(dropout))

    def reshape_heads_to_batch_dim(self, tensor):
        """Reshape `(batch, seq, heads * d)` into `(batch * heads, seq, d)`."""
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size, seq_len,
                                head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3).reshape(
            batch_size * head_size, seq_len, dim // head_size)
        return tensor

    def reshape_batch_dim_to_heads(self, tensor):
        """Reshape `(batch * heads, seq, d)` back into `(batch, seq, heads * d)`."""
        batch_size, seq_len, dim = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size // head_size,
                                head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(
            batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def reshape_batch_dim_to_heads_and_average(self, tensor):
        """Reshape `(batch * heads, q, k)` to `(batch, heads, q, k)` and average over heads."""
        batch_size, seq_len, seq_len2 = tensor.shape
        head_size = self.heads
        tensor = tensor.reshape(batch_size // head_size,
                                head_size, seq_len, seq_len2)
        return tensor.mean(1)

    def forward(self, hidden_states, context=None, mask=None, text_format_dict=None):
        """Attend from `hidden_states` to `context` (or to itself when `context` is None).

        Args:
            hidden_states: Query input, reshaped internally as a 3-D tensor.
            context: Key/value input; defaults to `hidden_states`.
            mask: Accepted for interface compatibility; not used.
            text_format_dict: Optional keyword options forwarded to
                `_attention` (e.g. `word_pos`, `font_size`). Ignored on the
                xformers path.

        Returns:
            `(hidden_states, attn_probs)`. `attn_probs` is the head-averaged
            attention matrix, or `None` on the xformers path, which does not
            expose attention probabilities.
        """
        if text_format_dict is None:
            text_format_dict = {}

        query = self.to_q(hidden_states)
        context = context if context is not None else hidden_states
        key = self.to_k(context)
        value = self.to_v(context)

        query = self.reshape_heads_to_batch_dim(query)
        key = self.reshape_heads_to_batch_dim(key)
        value = self.reshape_heads_to_batch_dim(value)

        # attention, what we cannot get enough of
        if self._use_memory_efficient_attention_xformers:
            hidden_states = self._memory_efficient_attention_xformers(
                query, key, value)
            # Some versions of xformers return output in fp32, cast it back to the dtype of the input
            hidden_states = hidden_states.to(query.dtype)
            attn_probs = None
        else:
            # No sliced implementation exists here, so full attention is used
            # regardless of `_slice_size`; this keeps both return values defined.
            hidden_states, attn_probs = self._attention(
                query, key, value, **text_format_dict)

        # linear proj
        hidden_states = self.to_out[0](hidden_states)
        # dropout
        hidden_states = self.to_out[1](hidden_states)
        return hidden_states, attn_probs

    def _qk(self, query, key):
        """Return the scaled dot-product scores `scale * query @ key^T`."""
        return torch.baddbmm(
            torch.empty(query.shape[0], query.shape[1], key.shape[1],
                        dtype=query.dtype, device=query.device),
            query,
            key.transpose(-1, -2),
            beta=0,  # beta=0 means the uninitialised input tensor is ignored
            alpha=self.scale,
        )

    def _attention(self, query, key, value, word_pos=None, font_size=None,
                   **kwargs):
        """Compute attention, optionally re-weighting selected context tokens.

        When this is a cross-attention layer and both `word_pos` and
        `font_size` are given, the exponentiated scores at `word_pos` are
        multiplied by `|font_size|` before normalisation and the resulting
        probabilities at `word_pos` are multiplied by `sign(font_size)`.
        Extra keyword arguments are accepted and ignored.
        """
        attention_scores = self._qk(query, key)

        # Font size:
        if self.is_cross_attn and word_pos is not None and font_size is not None:
            # The re-weighting path expects a context of exactly 77 tokens.
            assert key.shape[1] == 77
            attention_score_exp = attention_scores.exp()
            font_size_abs, font_size_sign = font_size.abs(), font_size.sign()
            attention_score_exp[:, :, word_pos] = attention_score_exp[:, :, word_pos].clone(
            )*font_size_abs
            attention_probs = attention_score_exp / \
                attention_score_exp.sum(-1, True)
            attention_probs[:, :, word_pos] *= font_size_sign
        else:
            attention_probs = attention_scores.softmax(dim=-1)

        hidden_states = torch.bmm(attention_probs, value)

        # reshape hidden_states
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        attention_probs = self.reshape_batch_dim_to_heads_and_average(
            attention_probs)
        return hidden_states, attention_probs

    def _memory_efficient_attention_xformers(self, query, key, value):
        """Compute attention with `xformers.ops.memory_efficient_attention`."""
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()
        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=None)
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states
661
+
662
+
663
class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
            Must be `"geglu"` or `"geglu-approximate"`.

    Raises:
        ValueError: If `activation_fn` is not one of the supported names.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        if activation_fn == "geglu":
            project_in = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            project_in = ApproximateGELU(dim, inner_dim)
        else:
            # Fail with a clear message instead of an unbound-variable error.
            raise ValueError(
                f"Unsupported activation_fn {activation_fn!r}; expected 'geglu' or 'geglu-approximate'."
            )

        self.net = nn.ModuleList([])
        # project in
        self.net.append(project_in)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(nn.Linear(inner_dim, dim_out))

    def forward(self, hidden_states):
        """Apply the projection-in, dropout and projection-out modules in order."""
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states
704
+
705
+
706
+ # feedforward
707
class GEGLU(nn.Module):
    r"""
    Gated linear unit with a GELU gate, see https://arxiv.org/abs/2002.05202.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        # One projection produces both the value half and the gate half.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def gelu(self, gate):
        """Apply GELU to `gate`, computing in float32 on mps devices."""
        if gate.device.type == "mps":
            # mps: gelu is not implemented for float16
            upcast = gate.to(dtype=torch.float32)
            return F.gelu(upcast).to(dtype=gate.dtype)
        return F.gelu(gate)

    def forward(self, hidden_states):
        """Project the input, split it in two along the last axis and gate one half with the other."""
        projected = self.proj(hidden_states)
        value, gate = projected.chunk(2, dim=-1)
        return value * self.gelu(gate)
729
+
730
+
731
class ApproximateGELU(nn.Module):
    """
    Linear projection followed by the sigmoid approximation of GELU.

    For more details, see section 2: https://arxiv.org/abs/1606.08415
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        """Return `y * sigmoid(1.702 * y)` where `y = proj(x)`."""
        projected = self.proj(x)
        gate = torch.sigmoid(1.702 * projected)
        return projected * gate
745
+
746
+
747
class AdaLayerNorm(nn.Module):
    """
    Layer norm whose scale and shift are predicted from a timestep embedding.
    """

    def __init__(self, embedding_dim, num_embeddings):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings, embedding_dim)
        self.silu = nn.SiLU()
        # Produces scale and shift together, hence twice the embedding size.
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

    def forward(self, x, timestep):
        """Normalise `x` and modulate it with the timestep-derived scale and shift."""
        conditioning = self.emb(timestep)
        conditioning = self.linear(self.silu(conditioning))
        # NOTE(review): the split is along dim 0, which yields per-channel
        # scale/shift only for a scalar `timestep` - confirm callers never
        # pass a batched timestep.
        scale, shift = conditioning.chunk(2)
        normalized = self.norm(x)
        return normalized * (1 + scale) + shift
764
+
765
+
766
class DualTransformer2DModel(nn.Module):
    """
    Wrapper holding two identically configured `Transformer2DModel`s whose outputs are mixed.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            Pass if the input is continuous. The number of channels in the input and output.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        norm_num_groups (`int`, *optional*, defaults to 32): Forwarded unchanged to each `Transformer2DModel`.
        cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
        attention_bias (`bool`, *optional*):
            Configure if the TransformerBlocks' attention should contain a bias parameter.
        sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
            Note that this is fixed at training time as it is used for learning a number of position embeddings. See
            `ImagePositionalEmbeddings`.
        num_vector_embeds (`int`, *optional*):
            Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
            Includes the class for the masked latent pixel.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
            The number of diffusion steps used during training. Note that this is fixed at training time as it is used
            to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
            up to, but not more than, `num_embeds_ada_norm` steps.
    """

    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        sample_size: Optional[int] = None,
        num_vector_embeds: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
    ):
        super().__init__()
        # Both transformers share exactly the same configuration.
        shared_config = dict(
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            in_channels=in_channels,
            num_layers=num_layers,
            dropout=dropout,
            norm_num_groups=norm_num_groups,
            cross_attention_dim=cross_attention_dim,
            attention_bias=attention_bias,
            sample_size=sample_size,
            num_vector_embeds=num_vector_embeds,
            activation_fn=activation_fn,
            num_embeds_ada_norm=num_embeds_ada_norm,
        )
        self.transformers = nn.ModuleList(
            [Transformer2DModel(**shared_config) for _ in range(2)]
        )

        # Variables that can be set by a pipeline:

        # Weight of the first condition's residual; the second gets `1 - mix_ratio`.
        self.mix_ratio = 0.5

        # The shape of `encoder_hidden_states` is expected to be
        # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
        self.condition_lengths = [77, 257]

        # Which transformer to use to encode which condition.
        # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
        self.transformer_index_for_condition = [1, 0]

    def forward(self, hidden_states, encoder_hidden_states, timestep=None, return_dict: bool = True):
        """
        Args:
            hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
                When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
                hidden_states
            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
                Conditional embeddings for cross attention layer. The two conditions are laid out back to back
                along dimension 1 with lengths given by `self.condition_lengths`.
            timestep ( `torch.long`, *optional*):
                Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a `Transformer2DModelOutput` instead of a plain tuple.

        Returns:
            `Transformer2DModelOutput` if `return_dict` is True, otherwise a one-element `tuple` whose first
            element is the sample tensor.
        """
        residual_input = hidden_states

        residuals = []
        offset = 0
        for branch in range(2):
            # Slice out this branch's condition tokens and run the transformer assigned to it.
            length = self.condition_lengths[branch]
            condition_state = encoder_hidden_states[:, offset: offset + length]
            transformer = self.transformers[self.transformer_index_for_condition[branch]]
            encoded_state = transformer(residual_input, condition_state, timestep, return_dict)[0]
            # Keep only what the transformer added on top of the input.
            residuals.append(encoded_state - residual_input)
            offset += length

        mixed = residuals[0] * self.mix_ratio + \
            residuals[1] * (1 - self.mix_ratio)
        output_states = mixed + residual_input

        if return_dict:
            return Transformer2DModelOutput(sample=output_states)
        return (output_states,)

    def _set_attention_slice(self, slice_size):
        """Forward the slice size to both wrapped transformers."""
        for transformer in self.transformers:
            transformer._set_attention_slice(slice_size)

    def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        """Forward the xformers flag to both wrapped transformers."""
        for transformer in self.transformers:
            transformer._set_use_memory_efficient_attention_xformers(
                use_memory_efficient_attention_xformers)
models/region_diffusion.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import collections
4
+ import torch.nn as nn
5
+ from functools import partial
6
+ from transformers import CLIPTextModel, CLIPTokenizer, logging
7
+ from diffusers import AutoencoderKL, PNDMScheduler, EulerDiscreteScheduler, DPMSolverMultistepScheduler
8
+ from models.unet_2d_condition import UNet2DConditionModel
9
+
10
+ # suppress partial model loading warning
11
+ logging.set_verbosity_error()
12
+
13
+
14
class RegionDiffusion(nn.Module):
    """Stable Diffusion wrapper that denoises with per-region prompts.

    Loads the VAE, CLIP tokenizer/text encoder and the project's UNet either
    from `runwayml/stable-diffusion-v1-5` or from a local copy under
    `pretrained-guidance/v1` (created on first use). `self.masks` must hold
    one mask per conditional prompt before `produce_latents` is called.
    """

    def __init__(self, device):
        """Load all model components onto `device`.

        Args:
            device: Device passed to `.to(...)` for the VAE, text encoder and UNet.
        """
        super().__init__()

        try:
            with open('./TOKEN', 'r') as f:
                # Drop newlines so a trailing "\n" does not end up in the token.
                self.token = f.read().replace('\n', '')
            print('[INFO] loaded hugging face access token from ./TOKEN!')
        except FileNotFoundError:
            # `True` asks the hub client to use the locally stored login token.
            self.token = True
            print('[INFO] try to load hugging face access token from the default place, make sure you have run `huggingface-cli login`.')

        self.device = device
        self.num_train_timesteps = 1000
        self.clip_gradient = False

        print('[INFO] loading stable diffusion...')
        local_pretrained_dir = 'pretrained-guidance/v1'
        if not os.path.isdir(local_pretrained_dir):
            save_pretrained = True
            load_paths = 'runwayml/stable-diffusion-v1-5'
            os.makedirs(local_pretrained_dir, exist_ok=True)
        else:
            save_pretrained = False
            load_paths = local_pretrained_dir

        # 1. Load the autoencoder model which will be used to decode the latents into image space.
        self.vae = AutoencoderKL.from_pretrained(
            load_paths, subfolder="vae", use_auth_token=self.token).to(self.device)

        # 2. Load the tokenizer and text encoder to tokenize and encode the text.
        self.tokenizer = CLIPTokenizer.from_pretrained(
            load_paths, subfolder='tokenizer', use_auth_token=self.token)
        self.text_encoder = CLIPTextModel.from_pretrained(
            load_paths, subfolder='text_encoder', use_auth_token=self.token).to(self.device)

        # 3. The UNet model for generating the latents.
        self.unet = UNet2DConditionModel.from_pretrained(
            load_paths, subfolder="unet", use_auth_token=self.token).to(self.device)

        if save_pretrained:
            # Cache the downloaded weights so later runs load from disk.
            self.vae.save_pretrained(os.path.join(local_pretrained_dir, 'vae'))
            self.tokenizer.save_pretrained(
                os.path.join(local_pretrained_dir, 'tokenizer'))
            self.text_encoder.save_pretrained(
                os.path.join(local_pretrained_dir, 'text_encoder'))
            self.unet.save_pretrained(
                os.path.join(local_pretrained_dir, 'unet'))

        self.scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                                       num_train_timesteps=self.num_train_timesteps, skip_prk_steps=True, steps_offset=1)
        self.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)

        self.masks = []
        self.attention_maps = None
        # Handles of registered forward hooks; empty until hooks are registered.
        self.forward_hooks = []
        self.color_loss = torch.nn.functional.mse_loss

        print('[INFO] loaded stable diffusion!')

    def get_text_embeds(self, prompt, negative_prompt):
        """Encode prompts and negative prompts into one embedding tensor.

        Args:
            prompt: List of prompt strings.
            negative_prompt: List of negative prompt strings.

        Returns:
            The negative-prompt embeddings concatenated in front of the
            prompt embeddings.
        """
        # Tokenize text and get embeddings
        text_input = self.tokenizer(
            prompt, padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt')

        with torch.no_grad():
            text_embeddings = self.text_encoder(
                text_input.input_ids.to(self.device))[0]

        # Do the same for unconditional embeddings. Truncate as well, so an
        # over-long negative prompt cannot exceed the encoder's maximum length.
        uncond_input = self.tokenizer(negative_prompt, padding='max_length',
                                      max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt')

        with torch.no_grad():
            uncond_embeddings = self.text_encoder(
                uncond_input.input_ids.to(self.device))[0]

        # Cat for final embeddings
        return torch.cat([uncond_embeddings, text_embeddings])

    def get_text_embeds_list(self, prompts):
        """Encode each prompt on its own and return the embeddings as a list."""
        text_embeddings = []
        for prompt in prompts:
            # Tokenize text and get embeddings
            text_input = self.tokenizer(
                [prompt], padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt')

            with torch.no_grad():
                text_embeddings.append(self.text_encoder(
                    text_input.input_ids.to(self.device))[0])

        return text_embeddings

    def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5,
                        latents=None, use_grad_guidance=False, text_format_dict=None):
        """Run the denoising loop with one prompt per region mask.

        Args:
            text_embeddings: Row 0 is the unconditional embedding; row `i + 1`
                is the embedding paired with `self.masks[i]`.
            height, width: Image size; latents are created at 1/8 resolution.
            num_inference_steps: Number of scheduler steps.
            guidance_scale: Classifier-free guidance weight.
            latents: Optional starting latents; sampled from a normal
                distribution when `None`.
            use_grad_guidance: When True, apply colour gradient guidance on
                steps with `t < text_format_dict['guidance_start_step']`.
            text_format_dict: Options forwarded to the UNet for the last
                region and read by the gradient guidance; `None` means `{}`.

        Returns:
            The final latents.

        Raises:
            ValueError: If the number of conditional embeddings does not
                match `len(self.masks)`.
        """
        if text_format_dict is None:
            text_format_dict = {}

        if latents is None:
            latents = torch.randn(
                (1, self.unet.in_channels, height // 8, width // 8), device=self.device)

        self.scheduler.set_timesteps(num_inference_steps)
        n_styles = text_embeddings.shape[0]-1
        if n_styles != len(self.masks):
            raise ValueError(
                f'expected {len(self.masks)} conditional embeddings (one per mask), got {n_styles}')

        last_style = len(self.masks) - 1
        with torch.autocast('cuda'):
            for t in self.scheduler.timesteps:

                # predict the noise residual
                with torch.no_grad():
                    noise_pred_uncond = self.unet(latents, t, encoder_hidden_states=text_embeddings[:1],
                                                  text_format_dict={})['sample']
                    noise_pred_text = None
                    for style_i, mask in enumerate(self.masks):
                        # Only the last region receives the rich-text options.
                        region_format = text_format_dict if style_i == last_style else {}
                        noise_pred_text_cur = self.unet(latents, t, encoder_hidden_states=text_embeddings[style_i+1:style_i+2],
                                                        text_format_dict=region_format)['sample']
                        if noise_pred_text is None:
                            noise_pred_text = noise_pred_text_cur * mask
                        else:
                            noise_pred_text = noise_pred_text + noise_pred_text_cur*mask

                # perform classifier-free guidance
                noise_pred = noise_pred_uncond + guidance_scale * \
                    (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents)[
                    'prev_sample']

                # apply gradient guidance
                if use_grad_guidance and t < text_format_dict['guidance_start_step']:
                    with torch.enable_grad():
                        if not latents.requires_grad:
                            latents.requires_grad = True
                        latents_0 = self.predict_x0(latents, noise_pred, t)
                        latents_inp = 1 / 0.18215 * latents_0
                        imgs = self.vae.decode(latents_inp).sample
                        imgs = (imgs / 2 + 0.5).clamp(0, 1)
                        loss_total = 0.
                        for attn_map, rgb_val in zip(text_format_dict['color_obj_atten'], text_format_dict['target_RGB']):
                            # Attention-weighted mean colour of the decoded image.
                            avg_rgb = (
                                imgs*attn_map[:, 0]).sum(2).sum(2)/attn_map[:, 0].sum()
                            loss = self.color_loss(
                                avg_rgb, rgb_val[:, :, 0, 0])*100
                            loss_total += loss
                        loss_total.backward()
                    # Gradient step on the latents, then cut the graph.
                    latents = (
                        latents - latents.grad * text_format_dict['color_guidance_weight']).detach().clone()

        return latents

    def predict_x0(self, x_t, eps_t, t):
        """Estimate the clean sample from `x_t` and the predicted noise `eps_t` at step `t`."""
        alpha_t = self.scheduler.alphas_cumprod[t]
        return (x_t - eps_t * torch.sqrt(1-alpha_t)) / torch.sqrt(alpha_t)

    def produce_attn_maps(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50,
                          guidance_scale=7.5, latents=None):
        """Run plain classifier-free-guided sampling and return the images.

        Attention maps are collected as a side effect only if evaluation
        hooks were registered beforehand.

        Returns:
            `uint8` numpy array of images in `(batch, height, width, channel)` order.
        """
        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds
        text_embeddings = self.get_text_embeds(
            prompts, negative_prompts)  # [2, 77, 768]
        if latents is None:
            latents = torch.randn(
                (text_embeddings.shape[0] // 2, self.unet.in_channels, height // 8, width // 8), device=self.device)

        self.scheduler.set_timesteps(num_inference_steps)

        with torch.autocast('cuda'):
            for t in self.scheduler.timesteps:
                # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
                latent_model_input = torch.cat([latents] * 2)

                # predict the noise residual
                with torch.no_grad():
                    noise_pred = self.unet(
                        latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']

                # perform guidance
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * \
                    (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents)[
                    'prev_sample']

        # Img latents -> imgs
        imgs = self.decode_latents(latents)  # [1, 3, 512, 512]

        # Img to Numpy
        imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
        imgs = (imgs * 255).round().astype('uint8')

        return imgs

    def decode_latents(self, latents):
        """Decode latents with the VAE and map the result to the `[0, 1]` range."""
        latents = 1 / 0.18215 * latents

        with torch.no_grad():
            imgs = self.vae.decode(latents).sample

        imgs = (imgs / 2 + 0.5).clamp(0, 1)

        return imgs

    def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50,
                      guidance_scale=7.5, latents=None, text_format_dict=None, use_grad_guidance=False):
        """Generate images from prompts using region-based denoising.

        Note: when a non-empty `text_format_dict` is passed, the key
        `'font_styles_embs'` is written into the caller's dict.

        Returns:
            `uint8` numpy array of images in `(batch, height, width, channel)` order.
        """
        if text_format_dict is None:
            text_format_dict = {}

        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds
        text_embeds = self.get_text_embeds(
            prompts, negative_prompts)  # [2, 77, 768]

        if text_format_dict:
            if text_format_dict.get('font_styles') is not None:
                text_format_dict['font_styles_embs'] = self.get_text_embeds_list(
                    text_format_dict['font_styles'])  # [2, 77, 768]
            else:
                text_format_dict['font_styles_embs'] = None

        latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents,
                                       num_inference_steps=num_inference_steps, guidance_scale=guidance_scale,
                                       use_grad_guidance=use_grad_guidance, text_format_dict=text_format_dict)  # [1, 4, 64, 64]

        # Img latents -> imgs
        imgs = self.decode_latents(latents)  # [1, 3, 512, 512]

        # Img to Numpy
        imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
        imgs = (imgs * 255).round().astype('uint8')

        return imgs

    def reset_attention_maps(self):
        r"""Function to reset attention maps.
        We reset attention maps because we append them while getting hooks
        to visualize attention maps for every step.
        Does nothing when no hooks are registered (`attention_maps` is None).
        """
        if self.attention_maps is None:
            return
        for key in self.attention_maps:
            self.attention_maps[key] = []

    def register_evaluation_hooks(self):
        r"""Function for registering hooks during evaluation.
        We mainly store activation maps averaged over queries.
        Any hooks registered earlier are removed first, so repeated calls do
        not stack duplicate hooks.
        """
        self.remove_evaluation_hooks()

        def save_activations(activations, name, module, inp, out):
            r"""
            PyTorch Forward hook to save outputs at each forward pass.
            """
            # out[0] - final output of attention layer
            # out[1] - attention probability matrix
            if 'attn2' in name:
                assert out[1].shape[-1] == 77
                activations[name].append(out[1].detach().cpu())
            else:
                assert out[1].shape[-1] != 77

        attention_dict = collections.defaultdict(list)
        for name, module in self.unet.named_modules():
            leaf_name = name.split('.')[-1]
            if 'attn' in leaf_name:
                # Register hook to obtain outputs at every attention layer.
                self.forward_hooks.append(module.register_forward_hook(
                    partial(save_activations, attention_dict, name)
                ))
        # attention_dict is a dictionary containing attention maps for every attention layer
        self.attention_maps = attention_dict

    def remove_evaluation_hooks(self):
        """Remove every registered forward hook and drop the collected attention maps."""
        # getattr keeps this safe even if hooks were never registered.
        for hook in getattr(self, 'forward_hooks', []):
            hook.remove()
        self.forward_hooks = []
        self.attention_maps = None
models/unet_2d_blocks.py ADDED
@@ -0,0 +1,1670 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import numpy as np
15
+ import torch
16
+ from torch import nn
17
+
18
+ from .attention import AttentionBlock, DualTransformer2DModel, Transformer2DModel
19
+ from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
20
+
21
+
22
def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
):
    """Build the down block class selected by ``down_block_type``.

    A leading ``"UNetRes"`` prefix on ``down_block_type`` is stripped before
    the name is matched. Each block class only receives the subset of the
    arguments that it is given below; the remaining arguments are ignored
    for that block type.

    Raises:
        ValueError: if ``down_block_type`` is ``"CrossAttnDownBlock2D"`` and
            ``cross_attention_dim`` is ``None``, or if the (prefix-stripped)
            name matches none of the known block types.
    """
    block_name = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type

    # Arguments that every supported down block is constructed with.
    base_kwargs = {
        "num_layers": num_layers,
        "in_channels": in_channels,
        "out_channels": out_channels,
        "add_downsample": add_downsample,
        "resnet_eps": resnet_eps,
        "resnet_act_fn": resnet_act_fn,
        "downsample_padding": downsample_padding,
    }

    if block_name == "DownBlock2D":
        return DownBlock2D(
            temb_channels=temb_channels,
            resnet_groups=resnet_groups,
            **base_kwargs,
        )

    if block_name == "AttnDownBlock2D":
        return AttnDownBlock2D(
            temb_channels=temb_channels,
            resnet_groups=resnet_groups,
            attn_num_head_channels=attn_num_head_channels,
            **base_kwargs,
        )

    if block_name == "CrossAttnDownBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
        return CrossAttnDownBlock2D(
            temb_channels=temb_channels,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            **base_kwargs,
        )

    # The skip variants are not given ``resnet_groups``.
    if block_name == "SkipDownBlock2D":
        return SkipDownBlock2D(temb_channels=temb_channels, **base_kwargs)

    if block_name == "AttnSkipDownBlock2D":
        return AttnSkipDownBlock2D(
            temb_channels=temb_channels,
            attn_num_head_channels=attn_num_head_channels,
            **base_kwargs,
        )

    # The encoder variants are not given ``temb_channels``.
    if block_name == "DownEncoderBlock2D":
        return DownEncoderBlock2D(resnet_groups=resnet_groups, **base_kwargs)

    if block_name == "AttnDownEncoderBlock2D":
        return AttnDownEncoderBlock2D(
            resnet_groups=resnet_groups,
            attn_num_head_channels=attn_num_head_channels,
            **base_kwargs,
        )

    raise ValueError(f"{block_name} does not exist.")
131
+
132
+
133
def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
):
    """Build the up block class selected by ``up_block_type``.

    A leading ``"UNetRes"`` prefix on ``up_block_type`` is stripped before
    the name is matched. Each block class only receives the subset of the
    arguments that it is given below; the remaining arguments are ignored
    for that block type.

    Raises:
        ValueError: if ``up_block_type`` is ``"CrossAttnUpBlock2D"`` and
            ``cross_attention_dim`` is ``None``, or if the (prefix-stripped)
            name matches none of the known block types.
    """
    block_name = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type

    # Arguments that every supported up block is constructed with.
    base_kwargs = {
        "num_layers": num_layers,
        "in_channels": in_channels,
        "out_channels": out_channels,
        "add_upsample": add_upsample,
        "resnet_eps": resnet_eps,
        "resnet_act_fn": resnet_act_fn,
    }
    # Extra arguments for the blocks that take a time embedding and a skip input.
    unet_kwargs = {
        "prev_output_channel": prev_output_channel,
        "temb_channels": temb_channels,
    }

    if block_name == "UpBlock2D":
        return UpBlock2D(resnet_groups=resnet_groups, **unet_kwargs, **base_kwargs)

    if block_name == "CrossAttnUpBlock2D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
        return CrossAttnUpBlock2D(
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            **unet_kwargs,
            **base_kwargs,
        )

    if block_name == "AttnUpBlock2D":
        return AttnUpBlock2D(
            resnet_groups=resnet_groups,
            attn_num_head_channels=attn_num_head_channels,
            **unet_kwargs,
            **base_kwargs,
        )

    # The skip variants are not given ``resnet_groups``.
    if block_name == "SkipUpBlock2D":
        return SkipUpBlock2D(**unet_kwargs, **base_kwargs)

    if block_name == "AttnSkipUpBlock2D":
        return AttnSkipUpBlock2D(
            attn_num_head_channels=attn_num_head_channels,
            **unet_kwargs,
            **base_kwargs,
        )

    # The decoder variants get neither ``temb_channels`` nor ``prev_output_channel``.
    if block_name == "UpDecoderBlock2D":
        return UpDecoderBlock2D(resnet_groups=resnet_groups, **base_kwargs)

    if block_name == "AttnUpDecoderBlock2D":
        return AttnUpDecoderBlock2D(
            resnet_groups=resnet_groups,
            attn_num_head_channels=attn_num_head_channels,
            **base_kwargs,
        )

    raise ValueError(f"{block_name} does not exist.")
240
+
241
+
242
class UNetMidBlock2D(nn.Module):
    """Middle block: one resnet followed by ``num_layers`` (attention, resnet) pairs.

    All resnets map ``in_channels`` to ``in_channels``. When ``resnet_groups``
    is ``None`` it falls back to ``min(in_channels // 4, 32)``.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        attention_type="default",
        output_scale_factor=1.0,
    ):
        super().__init__()

        self.attention_type = attention_type
        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)

        def build_resnet():
            # Every resnet in this block shares the same configuration.
            return ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )

        def build_attention():
            return AttentionBlock(
                in_channels,
                num_head_channels=attn_num_head_channels,
                rescale_output_factor=output_scale_factor,
                eps=resnet_eps,
                norm_num_groups=resnet_groups,
            )

        # The leading resnet always exists, even when num_layers == 0.
        resnet_layers = [build_resnet()]
        attention_layers = []
        for _ in range(num_layers):
            attention_layers.append(build_attention())
            resnet_layers.append(build_resnet())

        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)

    def forward(self, hidden_states, temb=None, encoder_states=None):
        """Run the leading resnet, then each attention followed by its resnet.

        ``encoder_states`` is only handed to the attention layers when
        ``attention_type`` is not ``"default"``.
        """
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            if self.attention_type == "default":
                attn_inputs = (hidden_states,)
            else:
                attn_inputs = (hidden_states, encoder_states)
            hidden_states = resnet(attn(*attn_inputs), temb)

        return hidden_states
318
+
319
+
320
class UNetMidBlock2DCrossAttn(nn.Module):
    """Middle block: one resnet followed by ``num_layers`` (transformer, resnet) pairs.

    Each attention layer is a ``Transformer2DModel``, or a
    ``DualTransformer2DModel`` when ``dual_cross_attention`` is true. All
    resnets map ``in_channels`` to ``in_channels``. When ``resnet_groups`` is
    ``None`` it falls back to ``min(in_channels // 4, 32)``.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        attention_type="default",
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        dual_cross_attention=False,
        use_linear_projection=False,
    ):
        super().__init__()

        self.attention_type = attention_type
        self.attn_num_head_channels = attn_num_head_channels
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        # there is always at least one resnet
        resnets = [
            ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        ]
        attentions = []

        for _ in range(num_layers):
            if not dual_cross_attention:
                attentions.append(
                    Transformer2DModel(
                        attn_num_head_channels,
                        in_channels // attn_num_head_channels,
                        in_channels=in_channels,
                        num_layers=1,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                    )
                )
            else:
                # The dual transformer is not given ``use_linear_projection``.
                attentions.append(
                    DualTransformer2DModel(
                        attn_num_head_channels,
                        in_channels // attn_num_head_channels,
                        in_channels=in_channels,
                        num_layers=1,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                    )
                )
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

    def set_attention_slice(self, slice_size):
        """Validate ``slice_size`` against the head count(s) and forward it to each attention layer.

        Raises:
            ValueError: if ``slice_size`` is not ``None`` and either does not
                divide every entry of ``attn_num_head_channels`` or exceeds
                the smallest entry.
        """
        head_dims = self.attn_num_head_channels
        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
            raise ValueError(
                f"Make sure slice_size {slice_size} is a common divisor of "
                f"the number of heads used in cross_attention: {head_dims}"
            )
        if slice_size is not None and slice_size > min(head_dims):
            raise ValueError(
                f"slice_size {slice_size} has to be smaller or equal to "
                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
            )

        for attn in self.attentions:
            attn._set_attention_slice(slice_size)

    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        """Forward the xformers flag to every attention layer."""
        for attn in self.attentions:
            attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None, text_format_dict=None):
        """Run the leading resnet, then each attention layer followed by its resnet.

        Args:
            hidden_states: input passed to the first resnet.
            temb: passed as the second argument to every resnet.
            encoder_hidden_states: passed to every attention layer.
            text_format_dict: passed to every attention layer; ``None``
                (the default) is replaced by a fresh empty dict per call.

        Returns:
            The output of the last resnet.
        """
        # A ``{}`` default would be a single dict shared by every call; create
        # a new one per call so a callee mutating it cannot leak state.
        if text_format_dict is None:
            text_format_dict = {}

        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            # Keywords mirror how CrossAttnDownBlock2D calls the same attention
            # classes, so the dict cannot land on a different positional slot.
            hidden_states = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                text_format_dict=text_format_dict,
            ).sample
            hidden_states = resnet(hidden_states, temb)

        return hidden_states
434
+
435
+
436
class AttnDownBlock2D(nn.Module):
    """Down block made of ``num_layers`` (resnet, attention) pairs plus an optional downsampler.

    The first resnet maps ``in_channels`` to ``out_channels``; later ones map
    ``out_channels`` to ``out_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        attention_type="default",
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
    ):
        super().__init__()
        self.attention_type = attention_type

        resnet_layers = []
        attention_layers = []
        for layer_idx in range(num_layers):
            # Only the first resnet sees the block's input width.
            layer_in_channels = in_channels if layer_idx == 0 else out_channels
            resnet_layers.append(
                ResnetBlock2D(
                    in_channels=layer_in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            attention_layers.append(
                AttentionBlock(
                    out_channels,
                    num_head_channels=attn_num_head_channels,
                    rescale_output_factor=output_scale_factor,
                    eps=resnet_eps,
                    norm_num_groups=resnet_groups,
                )
            )

        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)

        self.downsamplers = None
        if add_downsample:
            downsampler = Downsample2D(
                out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
            )
            self.downsamplers = nn.ModuleList([downsampler])

    def forward(self, hidden_states, temb=None):
        """Return ``(hidden_states, output_states)``.

        ``output_states`` is a tuple holding the result after every
        (resnet, attention) pair and, when downsamplers exist, the final
        downsampled result.
        """
        collected = []

        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = attn(resnet(hidden_states, temb))
            collected.append(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
            collected.append(hidden_states)

        return hidden_states, tuple(collected)
516
+
517
+
518
class CrossAttnDownBlock2D(nn.Module):
    """Down block made of ``num_layers`` (resnet, transformer) pairs plus an optional downsampler.

    Each attention layer is a ``Transformer2DModel``, or a
    ``DualTransformer2DModel`` when ``dual_cross_attention`` is true. The
    first resnet maps ``in_channels`` to ``out_channels``; later ones map
    ``out_channels`` to ``out_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        attention_type="default",
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
    ):
        super().__init__()
        resnets = []
        attentions = []

        self.attention_type = attention_type
        self.attn_num_head_channels = attn_num_head_channels

        for i in range(num_layers):
            # Only the first resnet sees the block's input width.
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            if not dual_cross_attention:
                attentions.append(
                    Transformer2DModel(
                        attn_num_head_channels,
                        out_channels // attn_num_head_channels,
                        in_channels=out_channels,
                        num_layers=1,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        only_cross_attention=only_cross_attention,
                    )
                )
            else:
                # The dual transformer is given neither ``use_linear_projection``
                # nor ``only_cross_attention``.
                attentions.append(
                    DualTransformer2DModel(
                        attn_num_head_channels,
                        out_channels // attn_num_head_channels,
                        in_channels=out_channels,
                        num_layers=1,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                    )
                )
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def set_attention_slice(self, slice_size):
        """Validate ``slice_size`` against the head count(s) and forward it to each attention layer.

        Raises:
            ValueError: if ``slice_size`` is not ``None`` and either does not
                divide every entry of ``attn_num_head_channels`` or exceeds
                the smallest entry.
        """
        head_dims = self.attn_num_head_channels
        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
            raise ValueError(
                f"Make sure slice_size {slice_size} is a common divisor of "
                f"the number of heads used in cross_attention: {head_dims}"
            )
        if slice_size is not None and slice_size > min(head_dims):
            raise ValueError(
                f"slice_size {slice_size} has to be smaller or equal to "
                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
            )

        for attn in self.attentions:
            attn._set_attention_slice(slice_size)

    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        """Forward the xformers flag to every attention layer."""
        for attn in self.attentions:
            attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None, text_format_dict=None):
        """Return ``(hidden_states, output_states)``.

        Args:
            hidden_states: input passed to the first resnet.
            temb: passed as the second argument to every resnet.
            encoder_hidden_states: passed to every attention layer.
            text_format_dict: passed to every attention layer; ``None``
                (the default) is replaced by a fresh empty dict per call.

        ``output_states`` is a tuple holding the result after every
        (resnet, attention) pair and, when downsamplers exist, the final
        downsampled result. When the module is in training mode and
        ``gradient_checkpointing`` is set, the layers are run through
        ``torch.utils.checkpoint.checkpoint``.
        """
        # A ``{}`` default would be a single dict shared by every call; create
        # a new one per call so a callee mutating it cannot leak state.
        if text_format_dict is None:
            text_format_dict = {}

        def create_custom_forward(module, **module_kwargs):
            # Bind keyword-only options in a closure; checkpoint() then only
            # has to forward the positional inputs.
            def custom_forward(*inputs):
                return module(*inputs, **module_kwargs)

            return custom_forward

        output_states = ()

        for resnet, attn in zip(self.resnets, self.attentions):
            if self.training and self.gradient_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                # text_format_dict goes by keyword, matching the eager branch below.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn, return_dict=False, text_format_dict=text_format_dict),
                    hidden_states,
                    encoder_hidden_states,
                )[0]
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    text_format_dict=text_format_dict,
                ).sample

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states
660
+
661
+
662
class DownBlock2D(nn.Module):
    """Down block made of ``num_layers`` resnets plus an optional downsampler.

    The first resnet maps ``in_channels`` to ``out_channels``; later ones map
    ``out_channels`` to ``out_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
    ):
        super().__init__()

        # Only the first resnet sees the block's input width.
        layer_inputs = [in_channels if idx == 0 else out_channels for idx in range(num_layers)]
        self.resnets = nn.ModuleList(
            [
                ResnetBlock2D(
                    in_channels=layer_in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
                for layer_in_channels in layer_inputs
            ]
        )

        self.downsamplers = None
        if add_downsample:
            downsampler = Downsample2D(
                out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
            )
            self.downsamplers = nn.ModuleList([downsampler])

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None):
        """Return ``(hidden_states, output_states)``.

        ``output_states`` is a tuple holding the result after every resnet
        and, when downsamplers exist, the final downsampled result. When the
        module is in training mode and ``gradient_checkpointing`` is set, the
        resnets are run through ``torch.utils.checkpoint.checkpoint``.
        """

        def wrap_for_checkpoint(module):
            def run(*inputs):
                return module(*inputs)

            return run

        collected = []

        for resnet in self.resnets:
            if self.training and self.gradient_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(wrap_for_checkpoint(resnet), hidden_states, temb)
            else:
                hidden_states = resnet(hidden_states, temb)
            collected.append(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
            collected.append(hidden_states)

        return hidden_states, tuple(collected)
739
+
740
+
741
class DownEncoderBlock2D(nn.Module):
    """Encoder down block: ``num_layers`` resnets built with ``temb_channels=None``, plus an optional downsampler.

    The first resnet maps ``in_channels`` to ``out_channels``; later ones map
    ``out_channels`` to ``out_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
    ):
        super().__init__()

        # Only the first resnet sees the block's input width.
        layer_inputs = [in_channels if idx == 0 else out_channels for idx in range(num_layers)]
        self.resnets = nn.ModuleList(
            [
                ResnetBlock2D(
                    in_channels=layer_in_channels,
                    out_channels=out_channels,
                    temb_channels=None,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
                for layer_in_channels in layer_inputs
            ]
        )

        self.downsamplers = None
        if add_downsample:
            downsampler = Downsample2D(
                out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
            )
            self.downsamplers = nn.ModuleList([downsampler])

    def forward(self, hidden_states):
        """Apply every resnet (with ``temb=None``) and then any downsamplers; return the result."""
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb=None)

        for downsampler in self.downsamplers if self.downsamplers is not None else ():
            hidden_states = downsampler(hidden_states)

        return hidden_states
799
+
800
+
801
class AttnDownEncoderBlock2D(nn.Module):
    """Encoder down block: ``num_layers`` (resnet, attention) pairs plus an optional downsampler.

    Resnets are built with ``temb_channels=None``. The first resnet maps
    ``in_channels`` to ``out_channels``; later ones map ``out_channels`` to
    ``out_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
    ):
        super().__init__()

        resnet_layers = []
        attention_layers = []
        for layer_idx in range(num_layers):
            # Only the first resnet sees the block's input width.
            layer_in_channels = in_channels if layer_idx == 0 else out_channels
            resnet_layers.append(
                ResnetBlock2D(
                    in_channels=layer_in_channels,
                    out_channels=out_channels,
                    temb_channels=None,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            attention_layers.append(
                AttentionBlock(
                    out_channels,
                    num_head_channels=attn_num_head_channels,
                    rescale_output_factor=output_scale_factor,
                    eps=resnet_eps,
                    norm_num_groups=resnet_groups,
                )
            )

        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)

        self.downsamplers = None
        if add_downsample:
            downsampler = Downsample2D(
                out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
            )
            self.downsamplers = nn.ModuleList([downsampler])

    def forward(self, hidden_states):
        """Apply each resnet (with ``temb=None``) followed by its attention, then any downsamplers."""
        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = attn(resnet(hidden_states, temb=None))

        for downsampler in self.downsamplers if self.downsamplers is not None else ():
            hidden_states = downsampler(hidden_states)

        return hidden_states
872
+
873
+
874
class AttnSkipDownBlock2D(nn.Module):
    """Down block with (resnet, attention) pairs and a separately downsampled skip input.

    Resnet group counts are derived from the channel widths as
    ``min(channels // 4, 32)``. When ``add_downsample`` is true the block also
    owns a downsampling resnet, a ``FirDownsample2D`` for the skip sample and a
    1x1 convolution that maps the 3-channel skip sample to ``out_channels``.
    ``downsample_padding`` is accepted but not used by this block.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        attention_type="default",
        output_scale_factor=np.sqrt(2.0),
        downsample_padding=1,
        add_downsample=True,
    ):
        super().__init__()
        self.attention_type = attention_type

        resnet_layers = []
        attention_layers = []
        for layer_idx in range(num_layers):
            # Only the first resnet sees the block's input width.
            layer_in_channels = in_channels if layer_idx == 0 else out_channels
            resnet_layers.append(
                ResnetBlock2D(
                    in_channels=layer_in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=min(layer_in_channels // 4, 32),
                    groups_out=min(out_channels // 4, 32),
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            attention_layers.append(
                AttentionBlock(
                    out_channels,
                    num_head_channels=attn_num_head_channels,
                    rescale_output_factor=output_scale_factor,
                    eps=resnet_eps,
                )
            )

        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)

        if not add_downsample:
            self.resnet_down = None
            self.downsamplers = None
            self.skip_conv = None
            return

        self.resnet_down = ResnetBlock2D(
            in_channels=out_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            eps=resnet_eps,
            groups=min(out_channels // 4, 32),
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            pre_norm=resnet_pre_norm,
            use_in_shortcut=True,
            down=True,
            kernel="fir",
        )
        self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
        self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))

    def forward(self, hidden_states, temb=None, skip_sample=None):
        """Return ``(hidden_states, output_states, skip_sample)``.

        ``output_states`` is a tuple holding the result after every
        (resnet, attention) pair and, when downsamplers exist, the combined
        result of the downsampling resnet and the projected skip sample.
        """
        collected = []

        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = attn(resnet(hidden_states, temb))
            collected.append(hidden_states)

        if self.downsamplers is not None:
            hidden_states = self.resnet_down(hidden_states, temb)
            for downsampler in self.downsamplers:
                skip_sample = downsampler(skip_sample)

            # Add the projected skip sample onto the downsampled features.
            hidden_states = self.skip_conv(skip_sample) + hidden_states
            collected.append(hidden_states)

        return hidden_states, tuple(collected), skip_sample
965
+
966
+
967
class SkipDownBlock2D(nn.Module):
    """Down block with resnets and a separately downsampled skip input.

    Resnet group counts are derived from the channel widths as
    ``min(channels // 4, 32)``. When ``add_downsample`` is true the block also
    owns a downsampling resnet, a ``FirDownsample2D`` for the skip sample and a
    1x1 convolution that maps the 3-channel skip sample to ``out_channels``.
    ``downsample_padding`` is accepted but not used by this block.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_pre_norm: bool = True,
        output_scale_factor=np.sqrt(2.0),
        add_downsample=True,
        downsample_padding=1,
    ):
        super().__init__()

        # Only the first resnet sees the block's input width.
        layer_inputs = [in_channels if idx == 0 else out_channels for idx in range(num_layers)]
        self.resnets = nn.ModuleList(
            [
                ResnetBlock2D(
                    in_channels=layer_in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=min(layer_in_channels // 4, 32),
                    groups_out=min(out_channels // 4, 32),
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
                for layer_in_channels in layer_inputs
            ]
        )

        if not add_downsample:
            self.resnet_down = None
            self.downsamplers = None
            self.skip_conv = None
            return

        self.resnet_down = ResnetBlock2D(
            in_channels=out_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            eps=resnet_eps,
            groups=min(out_channels // 4, 32),
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            pre_norm=resnet_pre_norm,
            use_in_shortcut=True,
            down=True,
            kernel="fir",
        )
        self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
        self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))

    def forward(self, hidden_states, temb=None, skip_sample=None):
        """Return ``(hidden_states, output_states, skip_sample)``.

        ``output_states`` is a tuple holding the result after every resnet
        and, when downsamplers exist, the combined result of the downsampling
        resnet and the projected skip sample.
        """
        collected = []

        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb)
            collected.append(hidden_states)

        if self.downsamplers is not None:
            hidden_states = self.resnet_down(hidden_states, temb)
            for downsampler in self.downsamplers:
                skip_sample = downsampler(skip_sample)

            # Add the projected skip sample onto the downsampled features.
            hidden_states = self.skip_conv(skip_sample) + hidden_states
            collected.append(hidden_states)

        return hidden_states, tuple(collected), skip_sample
1044
+
1045
+
1046
class AttnUpBlock2D(nn.Module):
    """Up block made of (resnet, attention) pairs followed by an optional upsampler.

    Each layer concatenates the running hidden states with one residual taken
    from the end of ``res_hidden_states_tuple`` along dim 1, then applies a
    ``ResnetBlock2D`` and an ``AttentionBlock``.

    Channel bookkeeping per layer:
        * the first resnet receives ``prev_output_channel`` channels from the
          previous block, later ones receive ``out_channels``;
        * the last resnet expects a residual with ``in_channels`` channels,
          earlier ones expect ``out_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attention_type="default",
        attn_num_head_channels=1,
        output_scale_factor=1.0,
        add_upsample=True,
    ):
        super().__init__()
        self.attention_type = attention_type

        resnet_layers = []
        attention_layers = []
        final_layer = num_layers - 1

        for layer_idx in range(num_layers):
            main_channels = prev_output_channel if layer_idx == 0 else out_channels
            skip_channels = in_channels if layer_idx == final_layer else out_channels

            resnet_layers.append(
                ResnetBlock2D(
                    in_channels=main_channels + skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            attention_layers.append(
                AttentionBlock(
                    out_channels,
                    num_head_channels=attn_num_head_channels,
                    rescale_output_factor=output_scale_factor,
                    eps=resnet_eps,
                    norm_num_groups=resnet_groups,
                )
            )

        # Registration order (attentions first) is kept as in the original.
        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)

        self.upsamplers = (
            nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
            if add_upsample
            else None
        )

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
        """Apply every (resnet, attention) pair, then the upsamplers if present.

        Residuals are consumed from the end of ``res_hidden_states_tuple``:
        the first layer uses the last entry, the second layer the one before
        it, and so on.
        """
        layers = zip(self.resnets, self.attentions)
        for depth, (resnet, attn) in enumerate(layers, start=1):
            skip = res_hidden_states_tuple[-depth]
            hidden_states = torch.cat([hidden_states, skip], dim=1)
            hidden_states = attn(resnet(hidden_states, temb))

        if self.upsamplers is None:
            return hidden_states

        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states)
        return hidden_states
1122
+
1123
+
1124
class CrossAttnUpBlock2D(nn.Module):
    """Up block made of (resnet, cross-attention transformer) pairs plus an optional upsampler.

    Each layer concatenates the running hidden states with one residual taken
    from the end of ``res_hidden_states_tuple`` along dim 1, applies a
    ``ResnetBlock2D`` and then a ``Transformer2DModel`` (or a
    ``DualTransformer2DModel`` when ``dual_cross_attention`` is set). The
    transformer additionally receives ``encoder_hidden_states`` and
    ``text_format_dict``.

    Channel bookkeeping per layer:
        * the first resnet receives ``prev_output_channel`` channels from the
          previous block, later ones receive ``out_channels``;
        * the last resnet expects a residual with ``in_channels`` channels,
          earlier ones expect ``out_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        attention_type="default",
        output_scale_factor=1.0,
        add_upsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
    ):
        super().__init__()
        resnets = []
        attentions = []

        self.attention_type = attention_type
        self.attn_num_head_channels = attn_num_head_channels

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            if not dual_cross_attention:
                attentions.append(
                    Transformer2DModel(
                        attn_num_head_channels,
                        out_channels // attn_num_head_channels,
                        in_channels=out_channels,
                        num_layers=1,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        only_cross_attention=only_cross_attention,
                    )
                )
            else:
                attentions.append(
                    DualTransformer2DModel(
                        attn_num_head_channels,
                        out_channels // attn_num_head_channels,
                        in_channels=out_channels,
                        num_layers=1,
                        cross_attention_dim=cross_attention_dim,
                        norm_num_groups=resnet_groups,
                    )
                )
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def set_attention_slice(self, slice_size):
        """Validate ``slice_size`` against the head count(s) and forward it to every attention module.

        Raises:
            ValueError: if ``slice_size`` does not divide every entry of
                ``attn_num_head_channels`` or exceeds the smallest one.
        """
        head_dims = self.attn_num_head_channels
        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
            raise ValueError(
                f"Make sure slice_size {slice_size} is a common divisor of "
                f"the number of heads used in cross_attention: {head_dims}"
            )
        if slice_size is not None and slice_size > min(head_dims):
            raise ValueError(
                f"slice_size {slice_size} has to be smaller or equal to "
                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
            )

        for attn in self.attentions:
            attn._set_attention_slice(slice_size)

        # NOTE: the original also reset ``self.gradient_checkpointing = False``
        # here, silently undoing a previous gradient-checkpointing request.
        # Attention slicing and checkpointing are independent settings, so
        # the flag is left untouched.

    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        """Forward the xformers toggle to every attention module of this block."""
        for attn in self.attentions:
            attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        encoder_hidden_states=None,
        upsample_size=None,
        text_format_dict=None,
    ):
        """Apply every (resnet, transformer) pair, then the upsamplers if present.

        Args:
            hidden_states: running activations from the previous block.
            res_hidden_states_tuple: residuals; one is consumed per layer,
                starting from the last entry.
            temb: embedding passed to each resnet.
            encoder_hidden_states: conditioning passed to each transformer.
            upsample_size: passed to each upsampler as its second argument.
            text_format_dict: extra formatting information passed to each
                transformer; ``None`` (the default) is treated as an empty dict.

        Returns:
            The hidden states after the last layer (and upsampling, if any).
        """
        # A fresh dict per call: a shared ``{}`` default would leak any
        # mutation made downstream into every later call.
        if text_format_dict is None:
            text_format_dict = {}

        for resnet, attn in zip(self.resnets, self.attentions):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:

                def resnet_forward(*inputs, _module=resnet):
                    return _module(*inputs)

                # Uses the same keyword arguments as the non-checkpointed
                # branch, so ``text_format_dict`` cannot land in another
                # positional slot of the transformer's forward.
                def attn_forward(states, context, _module=attn):
                    return _module(
                        states,
                        encoder_hidden_states=context,
                        text_format_dict=text_format_dict,
                        return_dict=False,
                    )

                hidden_states = torch.utils.checkpoint.checkpoint(resnet_forward, hidden_states, temb)
                hidden_states = torch.utils.checkpoint.checkpoint(
                    attn_forward, hidden_states, encoder_hidden_states
                )[0]
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    text_format_dict=text_format_dict,
                ).sample

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
1270
+
1271
+
1272
class UpBlock2D(nn.Module):
    """Attention-free up block: a stack of resnets followed by an optional upsampler.

    Each layer concatenates the running hidden states with one residual taken
    from the end of ``res_hidden_states_tuple`` along dim 1 and feeds the
    result through a ``ResnetBlock2D``.

    Channel bookkeeping per layer:
        * the first resnet receives ``prev_output_channel`` channels from the
          previous block, later ones receive ``out_channels``;
        * the last resnet expects a residual with ``in_channels`` channels,
          earlier ones expect ``out_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
    ):
        super().__init__()

        final_layer = num_layers - 1
        layers = []
        for layer_idx in range(num_layers):
            main_channels = prev_output_channel if layer_idx == 0 else out_channels
            skip_channels = in_channels if layer_idx == final_layer else out_channels
            layers.append(
                ResnetBlock2D(
                    in_channels=main_channels + skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
        self.resnets = nn.ModuleList(layers)

        self.upsamplers = (
            nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
            if add_upsample
            else None
        )

        self.gradient_checkpointing = False

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
        """Apply every resnet, then the upsamplers if present.

        Residuals are consumed from the end of ``res_hidden_states_tuple``.
        While training with ``gradient_checkpointing`` enabled, each resnet is
        run through ``torch.utils.checkpoint.checkpoint``. ``upsample_size``
        is passed to each upsampler as its second argument.
        """
        for depth, resnet in enumerate(self.resnets, start=1):
            skip = res_hidden_states_tuple[-depth]
            hidden_states = torch.cat([hidden_states, skip], dim=1)

            if not (self.training and self.gradient_checkpointing):
                hidden_states = resnet(hidden_states, temb)
                continue

            def run_resnet(*inputs, _module=resnet):
                return _module(*inputs)

            hidden_states = torch.utils.checkpoint.checkpoint(run_resnet, hidden_states, temb)

        if self.upsamplers is None:
            return hidden_states

        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states, upsample_size)
        return hidden_states
1344
+
1345
+
1346
class UpDecoderBlock2D(nn.Module):
    """Decoder up block: a plain resnet stack followed by an optional upsampler.

    The resnets are built with ``temb_channels=None`` and are always called
    with ``temb=None``. The first resnet maps ``in_channels`` to
    ``out_channels``; the remaining ones keep ``out_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
    ):
        super().__init__()

        layers = []
        for layer_idx in range(num_layers):
            layers.append(
                ResnetBlock2D(
                    in_channels=in_channels if layer_idx == 0 else out_channels,
                    out_channels=out_channels,
                    temb_channels=None,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
        self.resnets = nn.ModuleList(layers)

        self.upsamplers = (
            nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
            if add_upsample
            else None
        )

    def forward(self, hidden_states):
        """Pass ``hidden_states`` through every resnet and then every upsampler (if any)."""
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb=None)

        for upsampler in self.upsamplers or ():
            hidden_states = upsampler(hidden_states)

        return hidden_states
1398
+
1399
+
1400
class AttnUpDecoderBlock2D(nn.Module):
    """Decoder up block: (resnet, attention) pairs followed by an optional upsampler.

    The resnets are built with ``temb_channels=None`` and are always called
    with ``temb=None``. The first resnet maps ``in_channels`` to
    ``out_channels``; the remaining ones keep ``out_channels``. Every resnet is
    followed by an ``AttentionBlock`` over ``out_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        output_scale_factor=1.0,
        add_upsample=True,
    ):
        super().__init__()

        resnet_layers = []
        attention_layers = []
        for layer_idx in range(num_layers):
            resnet_layers.append(
                ResnetBlock2D(
                    in_channels=in_channels if layer_idx == 0 else out_channels,
                    out_channels=out_channels,
                    temb_channels=None,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            attention_layers.append(
                AttentionBlock(
                    out_channels,
                    num_head_channels=attn_num_head_channels,
                    rescale_output_factor=output_scale_factor,
                    eps=resnet_eps,
                    norm_num_groups=resnet_groups,
                )
            )

        # Registration order (attentions first) is kept as in the original.
        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)

        self.upsamplers = (
            nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
            if add_upsample
            else None
        )

    def forward(self, hidden_states):
        """Apply each resnet followed by its attention, then every upsampler (if any)."""
        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = attn(resnet(hidden_states, temb=None))

        for upsampler in self.upsamplers or ():
            hidden_states = upsampler(hidden_states)

        return hidden_states
1465
+
1466
+
1467
class AttnSkipUpBlock2D(nn.Module):
    """Up block with a resnet stack, one attention pass and a 3-channel skip-sample branch.

    Each resnet consumes the running hidden states concatenated (dim 1) with
    one residual taken from the end of ``res_hidden_states_tuple``. After the
    whole stack, ``self.attentions[0]`` is applied once.

    When ``add_upsample`` is true the block also owns:
        * ``skip_norm`` -> ``act`` -> ``skip_conv``: projects the hidden states
          to 3 channels, which are added to the (upsampled) ``skip_sample``;
        * ``resnet_up``: a FIR up-sampling ``ResnetBlock2D`` applied to the
          hidden states.

    ``upsample_padding`` is accepted for interface compatibility but is not
    used by the code in this class.
    """

    @staticmethod
    def _norm_groups(num_channels):
        """Return the group count used for a norm over ``num_channels``: a quarter of the channels, capped at 32."""
        return min(num_channels // 4, 32)

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        attention_type="default",
        output_scale_factor=np.sqrt(2.0),
        upsample_padding=1,
        add_upsample=True,
    ):
        super().__init__()
        self.attentions = nn.ModuleList([])
        self.resnets = nn.ModuleList([])

        self.attention_type = attention_type

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels
            resnet_total_channels = resnet_in_channels + res_skip_channels

            self.resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_total_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    # Fixed precedence: the original computed
                    # ``resnet_in_channels + res_skip_channels // 4``, which for
                    # small widths is not a quarter of the input channels and
                    # need not divide them (SkipUpBlock2D already divides the sum).
                    groups=self._norm_groups(resnet_total_channels),
                    groups_out=self._norm_groups(out_channels),
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        # NOTE(review): the loop below this comment originally lived inside the
        # layer loop as a single append per layer; it is kept per-layer so the
        # registered modules (and therefore state-dict keys) do not change,
        # even though ``forward`` only ever uses ``self.attentions[0]``.
        for _ in range(num_layers):
            self.attentions.append(
                AttentionBlock(
                    out_channels,
                    num_head_channels=attn_num_head_channels,
                    rescale_output_factor=output_scale_factor,
                    eps=resnet_eps,
                )
            )

        self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
        if add_upsample:
            self.resnet_up = ResnetBlock2D(
                in_channels=out_channels,
                out_channels=out_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=self._norm_groups(out_channels),
                groups_out=self._norm_groups(out_channels),
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
                use_in_shortcut=True,
                up=True,
                kernel="fir",
            )
            self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            self.skip_norm = torch.nn.GroupNorm(
                num_groups=self._norm_groups(out_channels), num_channels=out_channels, eps=resnet_eps, affine=True
            )
            self.act = nn.SiLU()
        else:
            self.resnet_up = None
            self.skip_conv = None
            self.skip_norm = None
            self.act = None

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
        """Run the resnet stack, one attention pass and the skip-sample branch.

        Returns:
            ``(hidden_states, skip_sample)``. ``skip_sample`` is ``0`` when none
            was supplied and the block has no ``resnet_up`` stage.
        """
        for resnet in self.resnets:
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            hidden_states = resnet(hidden_states, temb)

        # Only the first attention module is applied, once, after all resnets.
        hidden_states = self.attentions[0](hidden_states)

        if skip_sample is not None:
            skip_sample = self.upsampler(skip_sample)
        else:
            skip_sample = 0

        if self.resnet_up is not None:
            skip_sample_states = self.skip_norm(hidden_states)
            skip_sample_states = self.act(skip_sample_states)
            skip_sample_states = self.skip_conv(skip_sample_states)

            skip_sample = skip_sample + skip_sample_states

            hidden_states = self.resnet_up(hidden_states, temb)

        return hidden_states, skip_sample
1576
+
1577
+
1578
class SkipUpBlock2D(nn.Module):
    """Up block with a resnet stack and a 3-channel skip-sample branch (no attention).

    Each resnet consumes the running hidden states concatenated (dim 1) with
    one residual taken from the end of ``res_hidden_states_tuple``.

    When ``add_upsample`` is true the block also owns:
        * ``skip_norm`` -> ``act`` -> ``skip_conv``: projects the hidden states
          to 3 channels, which are added to the (upsampled) ``skip_sample``;
        * ``resnet_up``: a FIR up-sampling ``ResnetBlock2D`` applied to the
          hidden states.

    ``upsample_padding`` is accepted but not used by the code in this class.
    """

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_pre_norm: bool = True,
        output_scale_factor=np.sqrt(2.0),
        add_upsample=True,
        upsample_padding=1,
    ):
        super().__init__()
        self.resnets = nn.ModuleList([])

        # Norm groups over the block's output width: a quarter of the channels, capped at 32.
        out_groups = min(out_channels // 4, 32)
        final_layer = num_layers - 1

        for layer_idx in range(num_layers):
            main_channels = prev_output_channel if layer_idx == 0 else out_channels
            skip_channels = in_channels if layer_idx == final_layer else out_channels
            joined_channels = main_channels + skip_channels

            self.resnets.append(
                ResnetBlock2D(
                    in_channels=joined_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=min(joined_channels // 4, 32),
                    groups_out=out_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)

        if not add_upsample:
            self.resnet_up = None
            self.skip_conv = None
            self.skip_norm = None
            self.act = None
            return

        self.resnet_up = ResnetBlock2D(
            in_channels=out_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            eps=resnet_eps,
            groups=out_groups,
            groups_out=out_groups,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            pre_norm=resnet_pre_norm,
            use_in_shortcut=True,
            up=True,
            kernel="fir",
        )
        self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.skip_norm = torch.nn.GroupNorm(
            num_groups=out_groups, num_channels=out_channels, eps=resnet_eps, affine=True
        )
        self.act = nn.SiLU()

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
        """Run the resnet stack and the skip-sample branch.

        Returns:
            ``(hidden_states, skip_sample)``. ``skip_sample`` is ``0`` when none
            was supplied and the block has no ``resnet_up`` stage.
        """
        for depth, resnet in enumerate(self.resnets, start=1):
            skip = res_hidden_states_tuple[-depth]
            hidden_states = torch.cat([hidden_states, skip], dim=1)
            hidden_states = resnet(hidden_states, temb)

        skip_sample = 0 if skip_sample is None else self.upsampler(skip_sample)

        if self.resnet_up is None:
            return hidden_states, skip_sample

        rgb_update = self.skip_conv(self.act(self.skip_norm(hidden_states)))
        skip_sample = skip_sample + rgb_update
        hidden_states = self.resnet_up(hidden_states, temb)

        return hidden_states, skip_sample
models/unet_2d_condition.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.utils.checkpoint
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.models.modeling_utils import ModelMixin
23
+ from diffusers.utils import BaseOutput, logging
24
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
25
+ from .unet_2d_blocks import (
26
+ CrossAttnDownBlock2D,
27
+ CrossAttnUpBlock2D,
28
+ DownBlock2D,
29
+ UNetMidBlock2DCrossAttn,
30
+ UpBlock2D,
31
+ get_down_block,
32
+ get_up_block,
33
+ )
34
+
35
+
36
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
37
+
38
+
39
@dataclass
class UNet2DConditionOutput(BaseOutput):
    """Return container of the conditional UNet.

    Attributes:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Output of the model's last layer, i.e. the hidden states
            conditioned on the `encoder_hidden_states` input.
    """

    sample: torch.FloatTensor
48
+
49
+
50
+ class UNet2DConditionModel(ModelMixin, ConfigMixin):
51
+ r"""
52
+ UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
53
+ and returns sample shaped output.
54
+
55
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
56
+ implements for all the models (such as downloading or saving, etc.)
57
+
58
+ Parameters:
59
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
60
+ Height and width of input/output sample.
61
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
62
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
63
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
64
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
65
+ Whether to flip the sin to cos in the time embedding.
66
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
67
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
68
+ The tuple of downsample blocks to use.
69
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
70
+ The tuple of upsample blocks to use.
71
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
72
+ The tuple of output channels for each block.
73
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
74
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
75
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
76
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
77
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
78
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
79
+ cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
80
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
81
+ """
82
+
83
+ _supports_gradient_checkpointing = True
84
+
85
@register_to_config
def __init__(
    self,
    sample_size: Optional[int] = None,
    in_channels: int = 4,
    out_channels: int = 4,
    center_input_sample: bool = False,
    flip_sin_to_cos: bool = True,
    freq_shift: int = 0,
    down_block_types: Tuple[str] = (
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
    only_cross_attention: Union[bool, Tuple[bool]] = False,
    block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
    layers_per_block: int = 2,
    downsample_padding: int = 1,
    mid_block_scale_factor: float = 1,
    act_fn: str = "silu",
    norm_num_groups: int = 32,
    norm_eps: float = 1e-5,
    cross_attention_dim: int = 1280,
    attention_head_dim: Union[int, Tuple[int]] = 8,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = False,
    num_class_embeds: Optional[int] = None,
):
    """Build the UNet: input conv, time embedding, down/mid/up blocks and output head.

    A scalar ``only_cross_attention`` or ``attention_head_dim`` is broadcast
    to one entry per down block; for the up path both sequences (and
    ``block_out_channels``) are traversed in reverse. Every down block except
    the one at index ``len(block_out_channels) - 1`` downsamples, and every up
    block except the one at that index upsamples; ``self.num_upsamplers``
    counts the latter.
    """
    super().__init__()

    self.sample_size = sample_size
    base_channels = block_out_channels[0]
    time_embed_dim = base_channels * 4
    last_block_index = len(block_out_channels) - 1

    # input
    self.conv_in = nn.Conv2d(in_channels, base_channels, kernel_size=3, padding=(1, 1))

    # time
    self.time_proj = Timesteps(base_channels, flip_sin_to_cos, freq_shift)
    self.time_embedding = TimestepEmbedding(base_channels, time_embed_dim)

    # class embedding (the attribute only exists when class conditioning is configured)
    if num_class_embeds is not None:
        self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)

    self.down_blocks = nn.ModuleList([])
    self.mid_block = None
    self.up_blocks = nn.ModuleList([])

    if isinstance(only_cross_attention, bool):
        only_cross_attention = [only_cross_attention] * len(down_block_types)

    if isinstance(attention_head_dim, int):
        attention_head_dim = (attention_head_dim,) * len(down_block_types)

    # Keyword arguments shared by every down and up block.
    shared_block_kwargs = dict(
        temb_channels=time_embed_dim,
        resnet_eps=norm_eps,
        resnet_act_fn=act_fn,
        resnet_groups=norm_num_groups,
        cross_attention_dim=cross_attention_dim,
        dual_cross_attention=dual_cross_attention,
        use_linear_projection=use_linear_projection,
    )

    # down
    channels_out = base_channels
    for i, down_block_type in enumerate(down_block_types):
        channels_in, channels_out = channels_out, block_out_channels[i]
        self.down_blocks.append(
            get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=channels_in,
                out_channels=channels_out,
                add_downsample=i != last_block_index,
                attn_num_head_channels=attention_head_dim[i],
                downsample_padding=downsample_padding,
                only_cross_attention=only_cross_attention[i],
                **shared_block_kwargs,
            )
        )

    # mid
    self.mid_block = UNetMidBlock2DCrossAttn(
        in_channels=block_out_channels[-1],
        temb_channels=time_embed_dim,
        resnet_eps=norm_eps,
        resnet_act_fn=act_fn,
        output_scale_factor=mid_block_scale_factor,
        resnet_time_scale_shift="default",
        cross_attention_dim=cross_attention_dim,
        attn_num_head_channels=attention_head_dim[-1],
        resnet_groups=norm_num_groups,
        dual_cross_attention=dual_cross_attention,
        use_linear_projection=use_linear_projection,
    )

    # count how many layers upsample the images
    self.num_upsamplers = 0

    # up
    channels_reversed = list(reversed(block_out_channels))
    head_dims_reversed = list(reversed(attention_head_dim))
    only_cross_attention = list(reversed(only_cross_attention))
    channels_out = channels_reversed[0]
    for i, up_block_type in enumerate(up_block_types):
        previous_channels, channels_out = channels_out, channels_reversed[i]
        skip_channels = channels_reversed[min(i + 1, last_block_index)]

        # add upsample block for all BUT final layer
        add_upsample = i != last_block_index
        if add_upsample:
            self.num_upsamplers += 1

        self.up_blocks.append(
            get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=skip_channels,
                out_channels=channels_out,
                prev_output_channel=previous_channels,
                add_upsample=add_upsample,
                attn_num_head_channels=head_dims_reversed[i],
                only_cross_attention=only_cross_attention[i],
                **shared_block_kwargs,
            )
        )

    # out
    self.conv_norm_out = nn.GroupNorm(num_channels=base_channels, num_groups=norm_num_groups, eps=norm_eps)
    self.conv_act = nn.SiLU()
    self.conv_out = nn.Conv2d(base_channels, out_channels, kernel_size=3, padding=1)
231
+
232
def set_attention_slice(self, slice_size):
    """Validate ``slice_size`` and propagate it to every block that has attention layers.

    The value is checked against ``self.config.attention_head_dim`` (a single
    int or a sequence). Down blocks, the mid block and up blocks are visited in
    that order; down/up blocks without a non-``None`` ``attentions`` attribute
    are skipped, while the mid block is always updated.

    Raises:
        ValueError: if ``slice_size`` does not divide every head dimension or
            is larger than the smallest one.
    """
    head_dims = self.config.attention_head_dim
    if isinstance(head_dims, int):
        head_dims = [head_dims]

    if slice_size is not None:
        if any(dim % slice_size != 0 for dim in head_dims):
            raise ValueError(
                f"Make sure slice_size {slice_size} is a common divisor of "
                f"the number of heads used in cross_attention: {head_dims}"
            )
        if slice_size > min(head_dims):
            raise ValueError(
                f"slice_size {slice_size} has to be smaller or equal to "
                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
            )

    def _apply(blocks):
        for blk in blocks:
            if getattr(blk, "attentions", None) is not None:
                blk.set_attention_slice(slice_size)

    _apply(self.down_blocks)
    self.mid_block.set_attention_slice(slice_size)
    _apply(self.up_blocks)
255
+
256
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
    """Forward the xformers on/off flag to every attention-bearing block.

    Down and up blocks are only touched when they expose a non-``None``
    ``attentions`` attribute; the mid block is always updated.
    """
    flag = use_memory_efficient_attention_xformers

    def _has_attention(candidate):
        return getattr(candidate, "attentions", None) is not None

    for down_block in self.down_blocks:
        if _has_attention(down_block):
            down_block.set_use_memory_efficient_attention_xformers(flag)

    self.mid_block.set_use_memory_efficient_attention_xformers(flag)

    for up_block in self.up_blocks:
        if _has_attention(up_block):
            up_block.set_use_memory_efficient_attention_xformers(flag)
266
+
267
def _set_gradient_checkpointing(self, module, value=False):
    """Set ``module.gradient_checkpointing = value`` for down/up UNet blocks only."""
    checkpointable = (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)
    if not isinstance(module, checkpointable):
        # Any other module type is left untouched.
        return
    module.gradient_checkpointing = value
270
+
271
def forward(
    self,
    sample: torch.FloatTensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    class_labels: Optional[torch.Tensor] = None,
    text_format_dict=None,
    return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
    r"""
    Args:
        sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
        timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
        encoder_hidden_states (`torch.FloatTensor`): conditioning hidden states handed to every
            cross-attention block (upstream diffusers documents them as
            (batch, sequence_length, feature_dim) -- TODO confirm for this fork)
        class_labels (`torch.Tensor`, *optional*): class labels; required when
            `config.num_class_embeds` is set.
        text_format_dict (`dict`, *optional*): rich-text controls forwarded unchanged to the
            `CrossAttnDownBlock2D` / mid / `CrossAttnUpBlock2D` blocks. `None` (the default) is
            replaced by a fresh empty dict on every call.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.

    Returns:
        [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
        [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
        returning a tuple, the first element is the sample tensor.

    Raises:
        ValueError: if the model was configured with class embeddings but `class_labels` is missing.
    """
    # A literal `{}` default would be one dict object shared by every call; build it per call.
    if text_format_dict is None:
        text_format_dict = {}

    # By default samples have to be AT least a multiple of the overall upsampling factor.
    # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
    # However, the upsampling interpolation output size can be forced to fit any upsampling size
    # on the fly if necessary.
    default_overall_up_factor = 2**self.num_upsamplers

    # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
    forward_upsample_size = False
    upsample_size = None

    if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
        logger.info("Forward upsample size to force interpolation output size.")
        forward_upsample_size = True

    # 0. center input if necessary
    if self.config.center_input_sample:
        sample = 2 * sample - 1.0

    # 1. time
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
        timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
    elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)

    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    timesteps = timesteps.expand(sample.shape[0])

    t_emb = self.time_proj(timesteps)

    # timesteps does not contain any weights and will always return f32 tensors
    # but time_embedding might actually be running in fp16. so we need to cast here.
    # there might be better ways to encapsulate this.
    t_emb = t_emb.to(dtype=self.dtype)
    emb = self.time_embedding(t_emb)

    if self.config.num_class_embeds is not None:
        if class_labels is None:
            raise ValueError("class_labels should be provided when num_class_embeds > 0")
        class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
        emb = emb + class_emb

    # 2. pre-process
    sample = self.conv_in(sample)

    # 3. down
    down_block_res_samples = (sample,)
    for downsample_block in self.down_blocks:
        if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
            if isinstance(downsample_block, CrossAttnDownBlock2D):
                # Only the rich-text-aware block type accepts `text_format_dict`.
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    text_format_dict=text_format_dict
                )
            else:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                )
        else:
            # (A leftover ipdb breakpoint used to live here; it would stall a
            # non-interactive server, so it has been removed.)
            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
        down_block_res_samples += res_samples

    # 4. mid
    sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states,
                            text_format_dict=text_format_dict)

    # 5. up
    for i, upsample_block in enumerate(self.up_blocks):
        is_final_block = i == len(self.up_blocks) - 1

        # Pop as many skip connections as this block has resnets.
        res_samples = down_block_res_samples[-len(upsample_block.resnets):]
        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

        # if we have not reached the final block and need to forward the
        # upsample size, we do it here
        if not is_final_block and forward_upsample_size:
            upsample_size = down_block_res_samples[-1].shape[2:]

        if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
            if isinstance(upsample_block, CrossAttnUpBlock2D):
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    upsample_size=upsample_size,
                    text_format_dict=text_format_dict
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    upsample_size=upsample_size,
                )
        else:
            sample = upsample_block(
                hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
            )

    # 6. post-process
    sample = self.conv_norm_out(sample)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)

    if not return_dict:
        return (sample,)

    return UNet2DConditionOutput(sample=sample)
sample.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import time
4
+ import argparse
5
+ import imageio
6
+ import torch
7
+ import numpy as np
8
+ from torchvision import transforms
9
+
10
+ from models.region_diffusion import RegionDiffusion
11
+ from utils.attention_utils import get_token_maps
12
+ from utils.richtext_utils import seed_everything, parse_json, get_region_diffusion_input,\
13
+ get_attention_control_input, get_gradient_guidance_input
14
+
15
+
16
def main(args, param):
    """Render one prompt twice -- plain, then rich-text controlled -- and save both images.

    Args:
        args: Parsed CLI namespace; only ``run_dir`` is read here.
        param: Dict with keys ``text_input``, ``height``, ``width``,
            ``noise_index``, ``negative_prompt``, ``steps`` and
            ``guidance_weight``.
    """
    # Output folder.
    run_dir = args.run_dir
    os.makedirs(args.run_dir, exist_ok=True)

    # Region diffusion model on GPU when one is available.
    model = RegionDiffusion(
        torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

    # Rich-text spans -> per-attribute prompt lists.
    (base_text_prompt, style_text_prompts, footnote_text_prompts,
     footnote_target_tokens, color_text_prompts, color_names, color_rgbs,
     size_text_prompts_and_sizes, use_grad_guidance) = parse_json(param['text_input'])

    # Region prompts and the token ids each region owns.
    region_text_prompts, region_target_token_ids, base_tokens = get_region_diffusion_input(
        model, base_text_prompt, style_text_prompts, footnote_text_prompts,
        footnote_target_tokens, color_text_prompts, color_names)

    # Font-size based cross-attention control.
    text_format_dict = get_attention_control_input(
        model, base_tokens, size_text_prompts_and_sizes)

    # Colour guidance targets.
    text_format_dict, color_target_token_ids = get_gradient_guidance_input(
        model, base_tokens, color_text_prompts, color_rgbs, text_format_dict)

    height = param['height']
    width = param['width']
    seed = param['noise_index']
    negative_text = param['negative_prompt']
    seed_everything(seed)

    # Pass 1: plain-text generation, recording cross-attention maps.
    tic = time.time()
    if model.attention_maps is not None:
        model.reset_attention_maps()
    else:
        model.register_evaluation_hooks()
    plain_img = model.produce_attn_maps(
        [base_text_prompt], [negative_text],
        height=height, width=width, num_inference_steps=param['steps'],
        guidance_scale=param['guidance_weight'])
    imageio.imwrite(os.path.join(run_dir, 'seed%d_plain.png' % (seed)), plain_img[0])
    print('time lapses to get attention maps: %.4f' % (time.time()-tic))

    # Token maps are computed at 1/8 of the image resolution.
    map_w = width//8
    map_h = height//8
    color_obj_masks = get_token_maps(
        model.attention_maps, run_dir, map_w, map_h, color_target_token_ids, seed)
    model.masks = get_token_maps(
        model.attention_maps, run_dir, map_w, map_h, region_target_token_ids, seed, base_tokens)

    # Colour masks are upsampled back to image resolution.
    upsampled_masks = []
    for mask in color_obj_masks:
        upsampled_masks.append(transforms.functional.resize(
            mask, (height, width),
            interpolation=transforms.InterpolationMode.BICUBIC,
            antialias=True))
    color_obj_masks = upsampled_masks
    text_format_dict['color_obj_atten'] = color_obj_masks
    model.remove_evaluation_hooks()

    # Pass 2: rich-text generation, re-seeded with the same seed.
    tic = time.time()
    seed_everything(seed)
    rich_img = model.prompt_to_img(
        region_text_prompts, [negative_text],
        height=height, width=width, num_inference_steps=param['steps'],
        guidance_scale=param['guidance_weight'], use_grad_guidance=use_grad_guidance,
        text_format_dict=text_format_dict)
    print('time lapses to generate image from rich text: %.4f' %
          (time.time()-tic))
    imageio.imwrite(os.path.join(run_dir, 'seed%d_rich.png' % (seed)), rich_img[0])
85
+
86
+
87
if __name__ == '__main__':
    # Command-line entry point: collect the options, repackage them into the
    # `param` dict that `main` expects, and run one generation.
    parser = argparse.ArgumentParser()
    parser.add_argument('--run_dir', type=str, default='results/release/debug')
    parser.add_argument('--height', type=int, default=512)
    parser.add_argument('--width', type=int, default=512)
    parser.add_argument('--seed', type=int, default=6)
    parser.add_argument('--sample_steps', type=int, default=41)
    parser.add_argument(
        '--rich_text_json', type=str,
        default='{"ops":[{"insert":"A close-up 4k dslr photo of a "},{"attributes":{"link":"A cat wearing sunglasses and a bandana around its neck."},"insert":"cat"},{"insert":" riding a scooter. There are palm trees in the background."}]}')
    parser.add_argument('--negative_prompt', type=str, default='')
    parser.add_argument('--guidance_weight', type=float, default=8.5)
    args = parser.parse_args()

    # The rich text arrives as a JSON string and is decoded here.
    param = dict(
        text_input=json.loads(args.rich_text_json),
        height=args.height,
        width=args.width,
        guidance_weight=args.guidance_weight,
        steps=args.sample_steps,
        noise_index=args.seed,
        negative_prompt=args.negative_prompt,
    )

    main(args, param)
utils/.DS_Store ADDED
Binary file (6.15 kB). View file
 
utils/attention_utils.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os
3
+ import matplotlib as mpl
4
+ import matplotlib.pyplot as plt
5
+ import seaborn as sns
6
+ import torch
7
+ import torchvision
8
+
9
+ from pathlib import Path
10
+ import skimage
11
+ from skimage.morphology import erosion, square
12
+
13
+
14
def split_attention_maps_over_steps(attention_maps):
    r"""Regroup per-layer attention maps by step and split them by guidance branch.

    Args:
        attention_maps (dict): Maps a layer name to a sequence with one entry
            per sampling step. Each entry is sliced along its first axis:
            ``[:1]`` goes to the unconditional result and ``[1:2]`` to the
            conditional one (assumes the batch is ordered
            ``[unconditional, conditional]`` -- TODO confirm against the hooks).

    Returns:
        tuple: ``(attention_maps_cond, attention_maps_uncond)``, each a dict
        ``{step_num: {layer: sliced_map}}``.
    """
    attention_maps_cond = dict()  # Maps corresponding to conditional score
    attention_maps_uncond = dict()  # Maps corresponding to unconditional score

    for layer, maps_per_step in attention_maps.items():
        for step_num, step_map in enumerate(maps_per_step):
            # Layers may record different numbers of steps, so the per-step
            # dicts are created lazily.
            attention_maps_uncond.setdefault(step_num, dict())[layer] = step_map[:1]
            attention_maps_cond.setdefault(step_num, dict())[layer] = step_map[1:2]

    return attention_maps_cond, attention_maps_uncond
38
+
39
+
40
def plot_attention_maps(atten_map_list, obj_tokens, save_dir, seed, tokens_vis=None):
    r"""Save one heat-map strip per entry of `atten_map_list`.

    Args:
        atten_map_list (list): Up to three lists of per-object maps, in the
            order presoftmax / postsoftmax / postsoftmax_erosion. Each map is
            indexed as ``map[0]`` to obtain the 2-D image that is drawn.
        obj_tokens (list): Per-object token-id tensors; only used to build
            panel titles when `tokens_vis` is given.
        save_dir (str): Directory the PNG files are written to.
        seed (int): Seed, embedded in the output file names.
        tokens_vis (list, optional): Token strings; entry ``token_id - 1`` is
            used as the label for `token_id`.
    """
    atten_names = ['presoftmax', 'postsoftmax', 'postsoftmax_erosion']
    word_end = '</w>'
    cmap = plt.get_cmap('OrRd')

    # Iterate over the map lists themselves. The previous
    # `zip(atten_map_list, obj_tokens)` silently dropped plots whenever there
    # were fewer token groups than map lists.
    for i, attn_map in enumerate(atten_map_list):
        n_obj = len(attn_map)

        # One panel per object plus a narrow axis for the shared colour bar.
        fig, axs = plt.subplots(
            ncols=n_obj+1, gridspec_kw=dict(width_ratios=[1 for _ in range(n_obj)]+[0.1]))
        fig.set_figheight(3)
        fig.set_figwidth(3*n_obj+0.1)

        # Common colour range across all panels of this strip.
        vmax = 0
        vmin = 1
        for tid in range(n_obj):
            attention_map_cur = attn_map[tid]
            vmax = max(vmax, float(attention_map_cur.max()))
            vmin = min(vmin, float(attention_map_cur.min()))

        for tid in range(n_obj):
            sns.heatmap(
                attn_map[tid][0], annot=False, cbar=False, ax=axs[tid],
                cmap=cmap, vmin=vmin, vmax=vmax
            )
            axs[tid].set_axis_off()
            if tokens_vis is not None:
                if tid == n_obj-1:
                    # The last panel is the catch-all region.
                    axs_xlabel = 'other tokens'
                else:
                    pieces = []
                    for token_id in obj_tokens[tid]:
                        token = tokens_vis[token_id.item() - 1]
                        # Strip the end-of-word marker only when it is present,
                        # so tokens without it keep their last four characters.
                        if token.endswith(word_end):
                            token = token[:-len(word_end)]
                        pieces.append(token)
                    axs_xlabel = ''.join(pieces)
                axs[tid].set_title(axs_xlabel)

        norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        fig.colorbar(sm, cax=axs[-1])

        fig.tight_layout()
        fig.savefig(os.path.join(
            save_dir, 'token_mapes_seed%d_%s.png' % (seed, atten_names[i])), dpi=100)
        # Close only the figure created here; no stray figure is opened any more.
        plt.close(fig)
86
+
87
+
88
def get_token_maps(attention_maps, save_dir, width, height, obj_tokens, seed=0, tokens_vis=None,
                   preprocess=False):
    r"""Turn recorded cross-attention maps into one spatial mask per token group.

    Args:
        attention_maps (dict): Recorded maps, ``{layer: [per-step tensor]}``,
            as accepted by `split_attention_maps_over_steps`.
        save_dir (str): Directory for the diagnostic plots.
        width (int): Width of the returned masks.
        height (int): Height of the returned masks.
        obj_tokens (list): One tensor of token indices per group. A group
            whose first entry is ``-1`` is filled with the complement of the
            groups before it.
        seed (int): Seed, only used in plot file names.
        tokens_vis (list, optional): Token strings for plot titles.
        preprocess (bool): If True, return eroded and thresholded masks built
            from the first two groups plus their complement.

    Returns:
        list: One ``(1, 4, height, width)`` tensor per group, on CUDA when it
        is available and on CPU otherwise.
    """
    # Only the conditional branch is used.
    attention_maps_cond, _ = split_attention_maps_over_steps(
        attention_maps
    )

    selected_layers = {
        # 'down_blocks.0.attentions.0.transformer_blocks.0.attn2',
        # 'down_blocks.0.attentions.1.transformer_blocks.0.attn2',
        'down_blocks.1.attentions.0.transformer_blocks.0.attn2',
        # 'down_blocks.1.attentions.1.transformer_blocks.0.attn2',
        'down_blocks.2.attentions.0.transformer_blocks.0.attn2',
        'down_blocks.2.attentions.1.transformer_blocks.0.attn2',
        'mid_block.attentions.0.transformer_blocks.0.attn2',
        'up_blocks.1.attentions.0.transformer_blocks.0.attn2',
        'up_blocks.1.attentions.1.transformer_blocks.0.attn2',
        'up_blocks.1.attentions.2.transformer_blocks.0.attn2',
        # 'up_blocks.2.attentions.0.transformer_blocks.0.attn2',
        'up_blocks.2.attentions.1.transformer_blocks.0.attn2',
        # 'up_blocks.2.attentions.2.transformer_blocks.0.attn2',
        # 'up_blocks.3.attentions.0.transformer_blocks.0.attn2',
        # 'up_blocks.3.attentions.1.transformer_blocks.0.attn2',
        # 'up_blocks.3.attentions.2.transformer_blocks.0.attn2'
    }

    nsteps = len(attention_maps_cond)
    hw_ori = width * height
    resize = torchvision.transforms.functional.resize
    bicubic = torchvision.transforms.InterpolationMode.BICUBIC

    # attention_maps[obj_id] collects one map per (step, layer) pair.
    attention_maps = [[] for _ in obj_tokens]

    for step_num in range(nsteps):
        # The first ten steps are skipped.
        if step_num < 10:
            continue
        attention_maps_cur = attention_maps_cond[step_num]

        for layer in attention_maps_cur.keys():
            if layer not in selected_layers:
                continue

            attention_ind = attention_maps_cur[layer].cpu()

            # Attention maps are of shape [batch_size, nkeys, 77]
            # since they are averaged out while collecting from hooks to save memory.
            # Recover the 2-D layout of this layer from its key count.
            bs, hw, nclip = attention_ind.shape
            down_ratio = np.sqrt(hw_ori // hw)
            width_cur = int(width // down_ratio)
            height_cur = int(height // down_ratio)
            attention_ind = attention_ind.reshape(
                bs, height_cur, width_cur, nclip)
            for obj_id, obj_token in enumerate(obj_tokens):
                if obj_token[0] == -1:
                    # Complement of everything claimed by the earlier groups.
                    attention_map_prev = torch.stack(
                        [attention_maps[i][-1] for i in range(obj_id)]).sum(0)
                    attention_maps[obj_id].append(
                        attention_map_prev.max()-attention_map_prev)
                else:
                    # Max over the group's tokens, then upsample to the target size.
                    obj_attention_map = attention_ind[:, :, :, obj_token].max(-1, True)[
                        0].permute([3, 0, 1, 2])
                    obj_attention_map = resize(obj_attention_map, (height, width),
                                               interpolation=bicubic, antialias=True)
                    attention_maps[obj_id].append(obj_attention_map)

    # average attention maps over steps (and layers)
    attention_maps_averaged = [
        torch.cat(maps_of_obj).mean(0) for maps_of_obj in attention_maps]

    # Sharpen into a near one-hot assignment across groups with a
    # low-temperature softmax. (The former sum-normalisation was overwritten
    # by this step and has been dropped.)
    attention_maps_softmax = (
        torch.cat(attention_maps_averaged)/0.001).softmax(0)
    attention_maps_averaged_normalized = [
        attention_maps_softmax[i:i+1] for i in range(attention_maps_softmax.shape[0])]

    # Mirror the device choice made by the sampling script instead of
    # unconditionally calling .cuda(), which fails on CPU-only hosts.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if preprocess:
        # it is possible to preprocess the attention maps here
        selem = square(5)
        attention_maps_averaged_eroded = [erosion(skimage.img_as_float(
            mask[0].numpy()*255), selem) for mask in attention_maps_averaged_normalized[:2]]
        attention_maps_averaged_eroded = [(torch.from_numpy(mask).unsqueeze(
            0)/255. > 0.8).float() for mask in attention_maps_averaged_eroded]
        attention_maps_averaged_eroded.append(
            1 - torch.cat(attention_maps_averaged_eroded).sum(0, True))
        plot_attention_maps([attention_maps_averaged, attention_maps_averaged_normalized,
                             attention_maps_averaged_eroded], obj_tokens, save_dir, seed, tokens_vis)
        # Repeat each mask over four channels.
        return [attn_mask.unsqueeze(1).repeat([1, 4, 1, 1]).to(device)
                for attn_mask in attention_maps_averaged_eroded]

    plot_attention_maps([attention_maps_averaged, attention_maps_averaged_normalized],
                        obj_tokens, save_dir, seed, tokens_vis)
    return [attn_mask.unsqueeze(1).repeat([1, 4, 1, 1]).to(device)
            for attn_mask in attention_maps_averaged_normalized]
utils/richtext_utils.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ import random
5
+ import numpy as np
6
+
7
+ COLORS = {
8
+ 'brown': [165, 42, 42],
9
+ 'red': [255, 0, 0],
10
+ 'pink': [253, 108, 158],
11
+ 'orange': [255, 165, 0],
12
+ 'yellow': [255, 255, 0],
13
+ 'purple': [128, 0, 128],
14
+ 'green': [0, 128, 0],
15
+ 'blue': [0, 0, 255],
16
+ 'white': [255, 255, 255],
17
+ 'gray': [128, 128, 128],
18
+ 'black': [0, 0, 0],
19
+ }
20
+
21
+
22
def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Also exports ``PYTHONHASHSEED``. NOTE(review): hash randomisation is
    normally fixed at interpreter start-up, so this probably only affects
    child processes -- confirm if hash determinism matters.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    for reseed in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        reseed(seed)
28
+
29
+
30
def hex_to_rgb(hex_string, return_nearest_color=False):
    r"""
    Convert a hex colour triplet to an RGB tensor scaled to [0, 1].

    Args:
        hex_string (str): ``'#rrggbb'`` or ``'rrggbb'``; the 3-digit shorthand
            ``'#rgb'`` is expanded to six digits. Characters after the sixth
            digit are ignored.
        return_nearest_color (bool): Also return the name of the closest
            palette colour as found by `find_nearest_color`.

    Returns:
        A ``(1, 3, 1, 1)`` float tensor (on CUDA when available, else on CPU),
        or a ``(tensor, colour_name)`` tuple when `return_nearest_color` is True.

    Raises:
        ValueError: If the string does not contain valid hex digits.
    """
    # Remove '#' symbol if present
    hex_string = hex_string.lstrip('#')
    if len(hex_string) == 3:
        # CSS shorthand: each digit is doubled ('f0a' -> 'ff00aa').
        hex_string = ''.join(digit * 2 for digit in hex_string)
    # Convert hex values to integers
    red, green, blue = (int(hex_string[start:start + 2], 16) for start in (0, 2, 4))
    rgb = torch.FloatTensor((red, green, blue))[None, :, None, None]/255.
    # Only move to the GPU when there is one; a bare .cuda() raises on CPU-only hosts.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if return_nearest_color:
        # The palette lookup runs on the CPU copy.
        nearest_color = find_nearest_color(rgb)
        return rgb.to(device), nearest_color
    return rgb.to(device)
45
+
46
+
47
def find_nearest_color(rgb):
    r"""
    Find the name of the palette colour in `COLORS` closest to `rgb`.

    Args:
        rgb: Either a list/tuple of three 0-255 channel values, or a tensor
            already scaled to [0, 1] and shaped to broadcast against
            ``(1, 3, 1, 1)``.

    Returns:
        str: The key of `COLORS` with the smallest Euclidean distance.
    """
    if isinstance(rgb, (list, tuple)):
        rgb = torch.FloatTensor(rgb)[None, :, None, None]/255.
    if torch.is_tensor(rgb):
        # Distances are computed with torch on a detached CPU copy so CUDA or
        # grad-tracking inputs work too (np.linalg.norm rejects those).
        rgb = rgb.detach().cpu()
    color_names = list(COLORS)
    color_distance = torch.stack([
        (rgb - torch.FloatTensor(COLORS[name])[None, :, None, None]/255.).pow(2).sum().sqrt()
        for name in color_names])
    return color_names[torch.argmin(color_distance).item()]
57
+
58
+
59
def font2style(font):
    r"""
    Convert the font name to the style name.

    Args:
        font (str): Font identifier from the rich-text editor (case-sensitive).

    Returns:
        str: The style description appended to the prompt for that font.

    Raises:
        KeyError: If `font` has no associated style; the message lists the
            supported fonts.
    """
    styles = {'mirza': 'Claud Monet, impressionism, oil on canvas',
              'roboto': 'Ukiyoe',
              'cursive': 'Cyber Punk, futuristic, blade runner, william gibson, trending on artstation hq',
              'sofia': 'Pop Art, masterpiece, andy warhol',
              'slabo': 'Vincent Van Gogh',
              'inconsolata': 'Pixel Art, 8 bits, 16 bits',
              'ubuntu': 'Rembrandt',
              'Monoton': 'neon art, colorful light, highly details, octane render',
              'Akronim': 'Abstract Cubism, Pablo Picasso', }
    try:
        return styles[font]
    except KeyError:
        # Same exception type as before, but with an actionable message.
        raise KeyError(
            f'Unsupported font {font!r}; expected one of {sorted(styles)}') from None
72
+
73
+
74
def parse_json(json_str):
    r"""
    Convert a rich-text document into per-attribute prompt lists.

    Args:
        json_str (dict): Despite the name, the already-decoded document: a
            dict whose ``'ops'`` list holds spans with an ``'insert'`` text and
            optional ``'attributes'`` (``font``, ``link``, ``size``,
            ``strike``, ``color``).

    Returns:
        tuple: ``(base_text_prompt, style_text_prompts, footnote_text_prompts,
        footnote_target_tokens, color_text_prompts, color_names, color_rgbs,
        size_text_prompts_and_sizes, use_grad_guidance)``.
    """
    # initialize region-based attributes.
    base_text_prompt = ''
    style_text_prompts = []
    footnote_text_prompts = []
    footnote_target_tokens = []
    color_text_prompts = []
    color_rgbs = []
    color_names = []
    size_text_prompts_and_sizes = []

    # parse the attributes from JSON.
    prev_style = None
    prev_color_rgb = None
    use_grad_guidance = False
    for span in json_str['ops']:
        if not isinstance(span['insert'], str):
            # Non-text inserts (e.g. embedded images) carry no prompt text.
            continue
        text_prompt = span['insert'].rstrip('\n')
        base_text_prompt += text_prompt
        if text_prompt == ' ':
            # A lone separator space never forms a region of its own.
            continue
        if 'attributes' not in span:
            continue
        attributes = span['attributes']

        # Font -> style prompt; a span with the same style as the previous
        # styled span is merged into that prompt.
        if 'font' in attributes:
            style = font2style(attributes['font'])
            if prev_style == style:
                prev_text_prompt = style_text_prompts[-1].split('in the style of')[
                    0]
                style_text_prompts[-1] = prev_text_prompt + \
                    ' ' + text_prompt + f' in the style of {style}'
            else:
                style_text_prompts.append(
                    text_prompt + f' in the style of {style}')
            prev_style = style
        else:
            prev_style = None

        # Link -> footnote prompt attached to the span text.
        if 'link' in attributes:
            footnote_text_prompts.append(attributes['link'])
            footnote_target_tokens.append(text_prompt)

        # Size -> token weight; strike-through flips the sign. The last two
        # characters of the size are dropped (presumably a 'px' suffix --
        # TODO confirm against the editor configuration).
        font_size = 1
        if 'size' in attributes:
            font_size = float(attributes['size'][:-2])/3.
            if 'strike' in attributes:
                font_size = -font_size

        # Colour -> gradient guidance target. The previous colour is now
        # actually tracked, so the merge below is reachable; tensors are
        # compared with torch.equal because `==` is element-wise.
        if 'color' in attributes:
            use_grad_guidance = True
            color_rgb, nearest_color = hex_to_rgb(
                attributes['color'], True)
            if prev_color_rgb is not None and torch.equal(prev_color_rgb, color_rgb):
                color_text_prompts[-1] = color_text_prompts[-1] + \
                    ' ' + text_prompt
            else:
                color_rgbs.append(color_rgb)
                color_names.append(nearest_color)
                color_text_prompts.append(text_prompt)
            prev_color_rgb = color_rgb
        else:
            # Mirrors the reset rule used for styles above.
            prev_color_rgb = None

        if font_size != 1:
            size_text_prompts_and_sizes.append([text_prompt, font_size])
    return base_text_prompt, style_text_prompts, footnote_text_prompts, footnote_target_tokens,\
        color_text_prompts, color_names, color_rgbs, size_text_prompts_and_sizes, use_grad_guidance
137
+
138
+
139
def get_region_diffusion_input(model, base_text_prompt, style_text_prompts, footnote_text_prompts,
                               footnote_target_tokens, color_text_prompts, color_names):
    r"""
    Algorithm 1 in the paper: build one text prompt per region together with
    the base-prompt token positions that region owns.

    Args:
        model: Object exposing ``tokenizer._tokenize(str) -> list[str]``.
        base_text_prompt (str): The full plain-text prompt.
        style_text_prompts (list): Prompts of the form
            ``'<text> in the style of <style>'``.
        footnote_text_prompts (list): Footnote prompts, one per footnoted span.
        footnote_target_tokens (list): The span text each footnote refers to.
        color_text_prompts (list): Span texts that carry a colour.
        color_names (list): Colour name for each entry of `color_text_prompts`.

    Returns:
        tuple: ``(region_text_prompts, region_target_token_ids, base_tokens)``.
        The last region is the base prompt and owns every token position not
        claimed by an earlier region. Ids are ``torch.LongTensor``s.

    Raises:
        ValueError: If a span token does not occur in the base prompt tokens.
    """
    tokenize = model.tokenizer._tokenize
    base_tokens = tokenize(base_text_prompt)

    def token_positions(text):
        # Position of the FIRST occurrence of each token, shifted by one
        # (presumably to skip the start-of-text token -- TODO confirm).
        return [base_tokens.index(token)+1 for token in tokenize(text)]

    region_text_prompts = []
    region_target_token_ids = []

    # process the style text prompt
    for text_prompt in style_text_prompts:
        region_text_prompts.append(text_prompt)
        region_target_token_ids.append(
            token_positions(text_prompt.split('in the style of')[0]))

    # process the complementary text prompt
    for footnote_text_prompt, text_prompt in zip(footnote_text_prompts, footnote_target_tokens):
        region_text_prompts.append(footnote_text_prompt)
        region_target_token_ids.append(token_positions(text_prompt))

    # process the color text prompt
    for color_text_prompt, color_name in zip(color_text_prompts, color_names):
        region_text_prompts.append(color_name+' '+color_text_prompt)
        region_target_token_ids.append(token_positions(color_text_prompt))

    # process the remaining tokens without any attributes
    region_text_prompts.append(base_text_prompt)
    claimed_ids = {token_id for ids in region_target_token_ids for token_id in ids}
    target_token_ids_rest = [token_id for token_id in range(
        1, len(base_tokens)+1) if token_id not in claimed_ids]
    region_target_token_ids.append(target_token_ids_rest)

    region_target_token_ids = [torch.LongTensor(
        obj_token_id) for obj_token_id in region_target_token_ids]
    return region_text_prompts, region_target_token_ids, base_tokens
186
+
187
+
188
def get_attention_control_input(model, base_tokens, size_text_prompts_and_sizes):
    r"""
    Control the token impact using font sizes.

    Args:
        model: Object exposing ``tokenizer._tokenize(str) -> list[str]``.
        base_tokens (list): Tokens of the base prompt.
        size_text_prompts_and_sizes (list): ``[text, font_size]`` pairs.

    Returns:
        dict: ``{'word_pos': LongTensor | None, 'font_size': FloatTensor | None}``
        with one entry per token of every sized span. Both values are ``None``
        when no span carries a size. Tensors live on CUDA when available,
        otherwise on CPU.

    Raises:
        ValueError: If a span token does not occur in `base_tokens`.
    """
    word_pos = []
    font_sizes = []
    for text_prompt, font_size in size_text_prompts_and_sizes:
        for size_token in model.tokenizer._tokenize(text_prompt):
            # First occurrence in the base prompt, shifted by one.
            word_pos.append(base_tokens.index(size_token)+1)
            font_sizes.append(font_size)
    if word_pos:
        # A bare .cuda() raises on CPU-only hosts; fall back to CPU there.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        word_pos = torch.LongTensor(word_pos).to(device)
        font_sizes = torch.FloatTensor(font_sizes).to(device)
    else:
        word_pos = None
        font_sizes = None
    text_format_dict = {
        'word_pos': word_pos,
        'font_size': font_sizes,
    }
    return text_format_dict
210
+
211
+
212
def get_gradient_guidance_input(model, base_tokens, color_text_prompts, color_rgbs, text_format_dict,
                                guidance_start_step=999, color_guidance_weight=1):
    r"""
    Prepare the inputs for colour (gradient) guidance.

    Args:
        model: Object exposing ``tokenizer._tokenize(str) -> list[str]``.
        base_tokens (list): Tokens of the base prompt.
        color_text_prompts (list): Span texts that carry a colour.
        color_rgbs (list): Target RGB value for each coloured span.
        text_format_dict (dict): Updated IN PLACE with ``'target_RGB'``,
            ``'guidance_start_step'`` and ``'color_guidance_weight'``.
        guidance_start_step (int): Stored under ``'guidance_start_step'``.
        color_guidance_weight (float): Stored under ``'color_guidance_weight'``.

    Returns:
        tuple: ``(text_format_dict, color_target_token_ids)``. The ids are one
        ``torch.LongTensor`` per coloured span, followed by a final tensor with
        every token position that belongs to none of them.

    Raises:
        ValueError: If a span token does not occur in `base_tokens`.
    """
    tokenize = model.tokenizer._tokenize
    # First occurrence of each token in the base prompt, shifted by one.
    color_target_token_ids = [
        [base_tokens.index(color_token)+1 for color_token in tokenize(text_prompt)]
        for text_prompt in color_text_prompts]

    # Everything not claimed by a coloured span forms the trailing group.
    claimed_ids = {token_id for ids in color_target_token_ids for token_id in ids}
    color_target_token_ids_rest = [token_id for token_id in range(
        1, len(base_tokens)+1) if token_id not in claimed_ids]
    color_target_token_ids.append(color_target_token_ids_rest)
    color_target_token_ids = [torch.LongTensor(
        obj_token_id) for obj_token_id in color_target_token_ids]

    text_format_dict['target_RGB'] = color_rgbs
    text_format_dict['guidance_start_step'] = guidance_start_step
    text_format_dict['color_guidance_weight'] = color_guidance_weight
    return text_format_dict, color_target_token_ids