Text-to-Image
daoyuan98 commited on
Commit
54c32af
·
verified ·
1 Parent(s): a6c206d

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +276 -0
model.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ import numpy as np
3
+ import torch
4
+
5
+ from torch import Tensor, nn
6
+ from einops import rearrange
7
+
8
+ from layers import (DoubleStreamBlock, EmbedND, LastLayer,
9
+ MLPEmbedder, SingleStreamBlock,
10
+ timestep_embedding)
11
+
12
+ import torch.distributed as dist
13
+ from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid
14
+
15
+ from accelerate.logging import get_logger
16
+ logger = get_logger(__name__, log_level="INFO")
17
+
18
+
19
+
20
+
21
+ @dataclass
22
+ class FluxParams:
23
+ in_channels: int
24
+ vec_in_dim: int
25
+ context_in_dim: int
26
+ hidden_size: int
27
+ mlp_ratio: float
28
+ num_heads: int
29
+ depth: int
30
+ depth_single_blocks: int
31
+ axes_dim: list[int]
32
+ theta: int
33
+ qkv_bias: bool
34
+ guidance_embed: bool
35
+
36
+
37
+ class Flux(nn.Module):
38
+ """
39
+ Transformer model for flow matching on sequences.
40
+ """
41
+ _supports_gradient_checkpointing = True
42
+
43
+ def __init__(self, params: FluxParams):
44
+ super().__init__()
45
+
46
+ self.params = params
47
+ self.in_channels = params.in_channels
48
+ self.out_channels = self.in_channels
49
+ if params.hidden_size % params.num_heads != 0:
50
+ raise ValueError(
51
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
52
+ )
53
+ pe_dim = params.hidden_size // params.num_heads
54
+ if sum(params.axes_dim) != pe_dim:
55
+ raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
56
+ self.hidden_size = params.hidden_size
57
+ self.num_heads = params.num_heads
58
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
59
+
60
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
61
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
62
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
63
+ self.guidance_in = (
64
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
65
+ )
66
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
67
+
68
+
69
+ self.double_blocks = nn.ModuleList(
70
+ [
71
+ DoubleStreamBlock(
72
+ self.hidden_size,
73
+ self.num_heads,
74
+ mlp_ratio=params.mlp_ratio,
75
+ qkv_bias=params.qkv_bias
76
+ )
77
+ for i in range(1, params.depth+1)
78
+ ]
79
+ )
80
+
81
+ self.single_blocks = nn.ModuleList(
82
+ [
83
+ SingleStreamBlock(
84
+ self.hidden_size,
85
+ self.num_heads,
86
+ mlp_ratio=params.mlp_ratio
87
+ )
88
+ for i in range(1, params.depth_single_blocks+1)
89
+ ]
90
+ )
91
+
92
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
93
+ self.gradient_checkpointing = True
94
+
95
+ def _set_gradient_checkpointing(self, module, value=False):
96
+ if hasattr(module, "gradient_checkpointing"):
97
+ module.gradient_checkpointing = value
98
+
99
+ @property
100
+ def attn_processors(self):
101
+ # set recursively
102
+ processors = {}
103
+
104
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
105
+ if hasattr(module, "set_processor"):
106
+ processors[f"{name}.processor"] = module.processor
107
+
108
+ for sub_name, child in module.named_children():
109
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
110
+
111
+ return processors
112
+
113
+ for name, module in self.named_children():
114
+ fn_recursive_add_processors(name, module, processors)
115
+
116
+ return processors
117
+
118
+ def set_attn_processor(self, processor):
119
+ r"""
120
+ Sets the attention processor to use to compute attention.
121
+
122
+ Parameters:
123
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
124
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
125
+ for **all** `Attention` layers.
126
+
127
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
128
+ processor. This is strongly recommended when setting trainable attention processors.
129
+
130
+ """
131
+ count = len(self.attn_processors.keys())
132
+
133
+ if isinstance(processor, dict) and len(processor) != count:
134
+ raise ValueError(
135
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
136
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
137
+ )
138
+
139
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
140
+ if hasattr(module, "set_processor"):
141
+ if not isinstance(processor, dict):
142
+ module.set_processor(processor)
143
+ else:
144
+ module.set_processor(processor.pop(f"{name}.processor"))
145
+
146
+ for sub_name, child in module.named_children():
147
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
148
+
149
+ for name, module in self.named_children():
150
+ fn_recursive_attn_processor(name, module, processor)
151
+
152
+ def forward(
153
+ self,
154
+ img: Tensor,
155
+ img_ids: Tensor,
156
+ txt: Tensor,
157
+ txt_ids: Tensor,
158
+ timesteps: Tensor,
159
+ y: Tensor,
160
+ block_controlnet_hidden_states=None,
161
+ guidance: Tensor = None,
162
+ image_proj: Tensor = None,
163
+ ip_scale: Tensor = 1.0,
164
+ return_intermediate: bool = False,
165
+ ):
166
+
167
+ if return_intermediate:
168
+ intermediate_double = []
169
+ intermediate_single = []
170
+
171
+ # running on sequences img
172
+ img = self.img_in(img)
173
+ vec = self.time_in(timestep_embedding(timesteps, 256))
174
+ if self.params.guidance_embed:
175
+ if guidance is None:
176
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
177
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
178
+ vec = vec + self.vector_in(y)
179
+ txt = self.txt_in(txt)
180
+
181
+ ids = torch.cat((txt_ids, img_ids), dim=1)
182
+ pe = self.pe_embedder(ids)
183
+
184
+ if block_controlnet_hidden_states is not None:
185
+ controlnet_depth = len(block_controlnet_hidden_states)
186
+
187
+
188
+ for index_block, block in enumerate(self.double_blocks):
189
+
190
+ if self.training and self.gradient_checkpointing:
191
+
192
+ def create_custom_forward(module, return_dict=None):
193
+ def custom_forward(*inputs):
194
+ if return_dict is not None:
195
+ return module(*inputs, return_dict=return_dict)
196
+ else:
197
+ return module(*inputs)
198
+
199
+ return custom_forward
200
+
201
+ img, txt = torch.utils.checkpoint.checkpoint(
202
+ create_custom_forward(block),
203
+ img,
204
+ txt,
205
+ vec,
206
+ pe,
207
+ image_proj,
208
+ ip_scale,
209
+ use_reentrant=False
210
+ )
211
+
212
+ else:
213
+ img, txt = block(
214
+ img=img,
215
+ txt=txt,
216
+ vec=vec,
217
+ pe=pe,
218
+ image_proj=image_proj,
219
+ ip_scale=ip_scale
220
+ )
221
+
222
+
223
+ if return_intermediate:
224
+ intermediate_double.append(
225
+ [img, txt]
226
+ )
227
+
228
+ if block_controlnet_hidden_states is not None:
229
+ img = img + block_controlnet_hidden_states[index_block % 2]
230
+
231
+ img = torch.cat((txt, img), dim=1)
232
+ txt_dim = txt.shape[1]
233
+ for index_block, block in enumerate(self.single_blocks):
234
+
235
+ if self.training and self.gradient_checkpointing:
236
+
237
+ def create_custom_forward(module, return_dict=None):
238
+ def custom_forward(*inputs):
239
+ if return_dict is not None:
240
+ return module(*inputs, return_dict=return_dict)
241
+ else:
242
+ return module(*inputs)
243
+
244
+ return custom_forward
245
+
246
+ # ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
247
+ img = torch.utils.checkpoint.checkpoint(
248
+ create_custom_forward(block),
249
+ img,
250
+ vec,
251
+ pe,
252
+ use_reentrant=False
253
+ )
254
+
255
+ else:
256
+ img = block(img, vec=vec, pe=pe)
257
+
258
+
259
+ # if return_intermediate:
260
+ img_ = img[:, txt.shape[1]:, ...]
261
+ txt_ = img[:, :txt.shape[1], ...]
262
+
263
+ if return_intermediate:
264
+ intermediate_single.append(
265
+ [img_, txt_]
266
+ )
267
+
268
+ img = torch.cat([txt_, img_], dim=1)
269
+
270
+ img = img[:, txt.shape[1] :, ...]
271
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
272
+ if return_intermediate:
273
+ return img, intermediate_double, intermediate_single
274
+ else:
275
+ return img
276
+