KBlueLeaf commited on
Commit
dc99abb
1 Parent(s): b21b792

Upload stable_cascade.py

Browse files
Files changed (1) hide show
  1. stable_cascade.py +1623 -0
stable_cascade.py ADDED
@@ -0,0 +1,1623 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # コードは Stable Cascade からコピーし、一部修正しています。元ライセンスは MIT です。
2
+ # The code is copied from Stable Cascade and modified. The original license is MIT.
3
+ # https://github.com/Stability-AI/StableCascade
4
+
5
+ import math
6
+ from types import SimpleNamespace
7
+ from typing import List, Optional
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ import torchvision
13
+
14
+
15
+
16
def check_scale(tensor):
    """Return the mean absolute value of ``tensor`` as a 0-dim tensor.

    Gives a single number summarising the overall magnitude of the values.
    """
    return tensor.abs().mean()
18
+
19
+
20
+ # region VectorQuantize
21
+
22
+ # from torchtools https://github.com/pabloppp/pytorch-tools
23
+ # Copied here to avoid adding another dependency.
24
+
25
+
26
class vector_quantize(torch.autograd.Function):
    """Nearest-neighbour codebook lookup with a pass-through gradient.

    ``forward`` maps each row of ``x`` to the closest row of ``codebook``
    (squared Euclidean distance) and returns the selected rows together with
    their indices. ``backward`` copies the incoming gradient unchanged to ``x``
    and scatters it onto the selected codebook rows.
    """

    @staticmethod
    def forward(ctx, x, codebook):
        """Return ``(nearest_codes, indices)`` for the rows of ``x``.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            x: 2-D tensor, one vector per row.
            codebook: 2-D tensor, one code vector per row.
        """
        with torch.no_grad():
            # ||x - c||^2 = ||c||^2 + ||x||^2 - 2 * x @ c^T, computed in one addmm.
            squared_norms = codebook.pow(2).sum(dim=1) + x.pow(2).sum(dim=1, keepdim=True)
            distances = torch.addmm(squared_norms, x, codebook.t(), alpha=-2.0, beta=1.0)
            indices = distances.min(dim=1).indices

            ctx.save_for_backward(indices, codebook)
            ctx.mark_non_differentiable(indices)

            quantized = torch.index_select(codebook, 0, indices)
            return quantized, indices

    @staticmethod
    def backward(ctx, grad_output, grad_indices):
        """Return gradients for ``(x, codebook)``; ``grad_indices`` is unused."""
        # Input gradient: the quantisation step is treated as identity.
        grad_inputs = grad_output.clone() if ctx.needs_input_grad[0] else None

        grad_codebook = None
        if ctx.needs_input_grad[1]:
            # Codebook gradient: accumulate each row's gradient on the code it selected.
            indices, codebook = ctx.saved_tensors
            grad_codebook = torch.zeros_like(codebook).index_add_(0, indices, grad_output)

        return (grad_inputs, grad_codebook)
56
+
57
+
58
class VectorQuantize(nn.Module):
    """Vector-quantisation layer backed by an ``nn.Embedding`` codebook."""

    def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False):
        """
        Takes an input of variable size (as long as the last dimension matches the embedding size).
        Returns one tensor containing the nearest neighbour embeddings to each of the inputs,
        with the same size as the input, vq and commitment components for the loss as a tuple
        in the second output and the indices of the quantized vectors in the third:
        quantized, (vq_loss, commit_loss), indices

        Args:
            embedding_size: length of each code vector.
            k: number of entries in the codebook.
            ema_decay: decay used by the EMA codebook update.
            ema_loss: when true, the codebook is updated by EMA during training.
        """
        super().__init__()

        self.codebook = nn.Embedding(k, embedding_size)
        bound = 1.0 / k
        self.codebook.weight.data.uniform_(-bound, bound)
        self.vq = vector_quantize.apply

        self.ema_decay = ema_decay
        self.ema_loss = ema_loss
        if ema_loss:
            # Running statistics for the EMA update, one entry per code.
            self.register_buffer("ema_element_count", torch.ones(k))
            self.register_buffer("ema_weight_sum", torch.zeros_like(self.codebook.weight))

    def _laplace_smoothing(self, x, epsilon):
        """Smooth the counts in ``x`` by ``epsilon`` while preserving their total."""
        total = torch.sum(x)
        return (x + epsilon) / (total + x.size(0) * epsilon) * total

    def _updateEMA(self, z_e_x, indices):
        """Update the EMA buffers from one batch and overwrite the codebook weights."""
        decay = self.ema_decay
        assignment = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
        batch_count = assignment.sum(dim=0)
        batch_sum = torch.mm(assignment.t(), z_e_x)

        running_count = (decay * self.ema_element_count) + ((1 - decay) * batch_count)
        self.ema_element_count = self._laplace_smoothing(running_count, 1e-5)
        self.ema_weight_sum = (decay * self.ema_weight_sum) + ((1 - decay) * batch_sum)

        self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)

    def idx2vq(self, idx, dim=-1):
        """Look up code vectors for ``idx``; move the embedding axis to ``dim`` if given."""
        vectors = self.codebook(idx)
        return vectors if dim == -1 else vectors.movedim(-1, dim)

    def forward(self, x, get_losses=True, dim=-1):
        """Quantise ``x`` along ``dim``.

        Args:
            x: input whose ``dim`` axis has the embedding size.
            get_losses: when false, both loss entries are returned as ``None``.
            dim: axis holding the embedding vectors (moved to last internally).

        Returns:
            ``(quantized, (vq_loss, commit_loss), indices)`` where ``quantized``
            has the layout of the input and ``indices`` has the input shape
            without the embedding axis.
        """
        if dim != -1:
            x = x.movedim(dim, -1)
        flat = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x
        quantized, indices = self.vq(flat, self.codebook.weight.detach())

        if self.ema_loss and self.training:
            self._updateEMA(flat.detach(), indices.detach())

        # pick the graded embeddings after updating the codebook in order to have a more accurate commitment loss
        graded = torch.index_select(self.codebook.weight, dim=0, index=indices)
        vq_loss = commit_loss = None
        if get_losses:
            vq_loss = (graded - flat.detach()).pow(2).mean()
            commit_loss = (flat - graded.detach()).pow(2).mean()

        quantized = quantized.view(x.shape)
        if dim != -1:
            quantized = quantized.movedim(-1, dim)
        return quantized, (vq_loss, commit_loss), indices.view(x.shape[:-1])
118
+
119
+
120
+ # endregion
121
+
122
+
123
class EfficientNetEncoder(nn.Module):
    """EfficientNetV2-S feature extractor followed by a 1x1 projection to ``c_latent`` channels."""

    def __init__(self, c_latent=16):
        """Build the backbone (torchvision "DEFAULT" weights) and the latent mapper."""
        super().__init__()
        effnet = torchvision.models.efficientnet_v2_s(weights="DEFAULT")
        self.backbone = effnet.features.eval()
        projection = nn.Conv2d(1280, c_latent, kernel_size=1, bias=False)
        # then normalize them to have mean 0 and std 1
        normalizer = nn.BatchNorm2d(c_latent, affine=False)
        self.mapper = nn.Sequential(projection, normalizer)

    def forward(self, x):
        """Return the mapped backbone features for image batch ``x``."""
        features = self.backbone(x)
        return self.mapper(features)

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the first parameter of the module."""
        first_param = next(self.parameters())
        return first_param.dtype

    @property
    def device(self) -> torch.device:
        """Device of the first parameter of the module."""
        first_param = next(self.parameters())
        return first_param.device

    def encode(self, x):
        """
        Adapter so the encoder can be used like a VAE. Strictly speaking the caller
        should distinguish the two cases; this is a temporary workaround.

        Mirrors the call pattern ``vae.encode(img).latent_dist.sample()``: the
        returned object exposes ``latent_dist.sample()``, which yields the
        features computed here.
        """
        latents = self(x)

        def sample():
            return latents

        return SimpleNamespace(latent_dist=SimpleNamespace(sample=sample))
151
+
152
+
153
+ # This is a rather crude implementation.
154
+ # Training from scratch is unlikely, so it is left disabled.
155
+
156
+ # class Linear(torch.nn.Linear):
157
+ # def reset_parameters(self):
158
+ # return None
159
+
160
+ # class Conv2d(torch.nn.Conv2d):
161
+ # def reset_parameters(self):
162
+ # return None
163
+
164
+ from torch.nn import Conv2d
165
+ from torch.nn import Linear
166
+
167
+
168
class Attention2D(nn.Module):
    """Multi-head attention applied to a 2-D feature map flattened into a token sequence."""

    def __init__(self, c, nhead, dropout=0.0):
        """Create a batch-first ``nn.MultiheadAttention`` with ``c`` channels and ``nhead`` heads."""
        super().__init__()
        self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True)

    def forward(self, x, kv, self_attn=False):
        """Attend from the pixels of ``x`` to ``kv``.

        Args:
            x: feature map; everything after the first two axes is flattened into tokens.
            kv: conditioning sequence used as keys and values.
            self_attn: when true, the flattened pixels are prepended to ``kv`` so
                the map also attends to itself.

        Returns:
            Tensor with the same shape as ``x``.
        """
        spatial_shape = x.shape
        # B x C x H x W -> B x (H*W) x C
        tokens = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1)
        context = torch.cat([tokens, kv], dim=1) if self_attn else kv
        attended, _ = self.attn(tokens, context, context, need_weights=False)
        return attended.permute(0, 2, 1).view(*spatial_shape)
181
+
182
+
183
class LayerNorm2d(nn.LayerNorm):
    """``nn.LayerNorm`` for channels-first (B, C, H, W) feature maps.

    The input is permuted to channels-last, normalised by the parent class,
    and permuted back.
    """

    def __init__(self, *args, **kwargs):
        """Accept exactly the arguments of ``nn.LayerNorm``."""
        super().__init__(*args, **kwargs)

    def forward(self, x):
        """Normalise ``x`` over its channel axis (axis 1) and return the same layout."""
        channels_last = x.permute(0, 2, 3, 1)
        normalized = super().forward(channels_last)
        return normalized.permute(0, 3, 1, 2)
189
+
190
+
191
class GlobalResponseNorm(nn.Module):
    "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"

    def __init__(self, dim):
        """Create zero-initialised ``gamma`` and ``beta`` of shape (1, 1, 1, dim)."""
        super().__init__()
        param_shape = (1, 1, 1, dim)
        self.gamma = nn.Parameter(torch.zeros(*param_shape))
        self.beta = nn.Parameter(torch.zeros(*param_shape))

    def forward(self, x):
        """Apply global response normalisation to a channels-last input.

        The L2 norm is taken over axes 1 and 2, divided by its mean over the
        last axis, and used to rescale ``x``; the result is added back to ``x``
        as a residual. With the zero initialisation the layer is the identity.
        """
        global_norm = x.norm(p=2, dim=(1, 2), keepdim=True)
        relative_norm = global_norm / (global_norm.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * relative_norm) + self.beta + x
203
+
204
+
205
class ResBlock(nn.Module):
    """Residual block: depthwise conv -> LayerNorm2d -> channelwise MLP, plus skip connection.

    An optional ``x_skip`` tensor is concatenated on the channel axis before
    the channelwise MLP.
    """

    def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
        """
        Args:
            c: number of input/output channels.
            c_skip: channels of the optional skip tensor concatenated in ``forward``.
            kernel_size: kernel size of the depthwise convolution.
            dropout: dropout probability inside the channelwise MLP.
        """
        super().__init__()
        self.depthwise = Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
        self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
        hidden = c * 4
        self.channelwise = nn.Sequential(
            Linear(c + c_skip, hidden),
            nn.GELU(),
            GlobalResponseNorm(hidden),
            nn.Dropout(dropout),
            Linear(hidden, c),
        )

        self.gradient_checkpointing = False
        self.factor = 1

    def set_factor(self, k):
        """Divide the depthwise bias and the last linear layer's weight and bias by ``k``.

        Takes effect only while ``self.factor`` is still 1, so a non-unit
        factor is applied at most once.
        """
        if self.factor == 1:
            self.factor = k
            output_layer = self.channelwise[4]
            self.depthwise.bias.data /= k
            output_layer.weight.data /= k
            output_layer.bias.data /= k

    def set_gradient_checkpointing(self, value):
        """Enable or disable gradient checkpointing for this block."""
        self.gradient_checkpointing = value

    def forward_body(self, x, x_skip=None):
        """Run the block without checkpointing and return ``branch(x) + x``."""
        residual = x
        hidden = self.norm(self.depthwise(x))
        if x_skip is not None:
            hidden = torch.cat([hidden, x_skip], dim=1)
        # The MLP works channels-last; permute in and back out.
        hidden = self.channelwise(hidden.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return hidden + residual

    def forward(self, x, x_skip=None):
        """Run the block, using gradient checkpointing when training with it enabled."""
        if not (self.training and self.gradient_checkpointing):
            return self.forward_body(x, x_skip)
        return torch.utils.checkpoint.checkpoint(self.forward_body, x, x_skip)
263
+
264
+
265
class AttnBlock(nn.Module):
    """Residual attention block: the normalised feature map attends to mapped conditioning tokens."""

    def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
        """
        Args:
            c: channels of the feature map.
            c_cond: feature size of the conditioning tokens before mapping.
            nhead: number of attention heads.
            self_attn: forwarded to ``Attention2D`` so the map also attends to itself.
            dropout: attention dropout probability.
        """
        super().__init__()
        self.self_attn = self_attn
        self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
        self.attention = Attention2D(c, nhead, dropout)
        self.kv_mapper = nn.Sequential(nn.SiLU(), Linear(c_cond, c))

        self.gradient_checkpointing = False
        self.factor = 1

    def set_factor(self, k):
        """Divide the attention output projection (weight and bias) by ``k``.

        Takes effect only while ``self.factor`` is still 1.
        """
        if self.factor == 1:
            self.factor = k
            out_proj = self.attention.attn.out_proj
            out_proj.weight.data /= k
            if out_proj.bias is not None:
                out_proj.bias.data /= k

    def set_gradient_checkpointing(self, value):
        """Enable or disable gradient checkpointing for this block."""
        self.gradient_checkpointing = value

    def forward_body(self, x, kv):
        """Return ``x`` plus the attention output, without checkpointing."""
        context = self.kv_mapper(kv)
        attended = self.attention(self.norm(x), context, self_attn=self.self_attn)
        return x + attended

    def forward(self, x, kv):
        """Run the block, using gradient checkpointing when training with it enabled."""
        if not (self.training and self.gradient_checkpointing):
            return self.forward_body(x, kv)
        return torch.utils.checkpoint.checkpoint(self.forward_body, x, kv)
309
+
310
+
311
class FeedForwardBlock(nn.Module):
    """Residual channelwise MLP applied after a ``LayerNorm2d``."""

    def __init__(self, c, dropout=0.0):
        """
        Args:
            c: number of channels.
            dropout: dropout probability inside the MLP.
        """
        super().__init__()
        self.norm = LayerNorm2d(c, elementwise_affine=False, eps=1e-6)
        hidden = c * 4
        self.channelwise = nn.Sequential(
            Linear(c, hidden),
            nn.GELU(),
            GlobalResponseNorm(hidden),
            nn.Dropout(dropout),
            Linear(hidden, c),
        )

        self.gradient_checkpointing = False

    def set_gradient_checkpointing(self, value):
        """Enable or disable gradient checkpointing for this block."""
        self.gradient_checkpointing = value

    def forward_body(self, x):
        """Return ``x`` plus the MLP output, without checkpointing."""
        # The MLP works channels-last; permute in and back out.
        channels_last = self.norm(x).permute(0, 2, 3, 1)
        update = self.channelwise(channels_last).permute(0, 3, 1, 2)
        return x + update

    def forward(self, x):
        """Run the block, using gradient checkpointing when training with it enabled."""
        if not (self.training and self.gradient_checkpointing):
            return self.forward_body(x)
        return torch.utils.checkpoint.checkpoint(self.forward_body, x)
343
+
344
+
345
class TimestepBlock(nn.Module):
    """Scale-and-shift modulation of a feature map by timestep-style embeddings.

    One linear mapper is created for the main embedding and one more per entry
    of ``conds``; each produces a ``(scale, shift)`` pair and the pairs are
    summed before being applied as ``x * (1 + scale) + shift``.
    """

    def __init__(self, c, c_timestep, conds=["sca"]):
        """
        Args:
            c: channels of the feature map being modulated.
            c_timestep: size of each individual embedding chunk.
            conds: names of the extra conditionings; a ``mapper_<name>`` layer
                is registered for each.
        """
        super().__init__()
        self.mapper = Linear(c_timestep, c * 2)
        # Store a copy: keeping a reference to the (shared, mutable) default list
        # would let a mutation of one instance's ``conds`` leak into every
        # instance created afterwards.
        self.conds = list(conds)
        for cname in self.conds:
            setattr(self, f"mapper_{cname}", Linear(c_timestep, c * 2))
        self.factor = 1

    def set_factor(self, k, ext_k):
        """Rescale the mappers in place so the block's output is scaled down.

        With ``k_factor = k / ext_k``, the scale half of every mapper is divided
        by ``k_factor`` and its bias is shifted so that, summed over all
        ``len(conds) + 1`` mappers, ``1 + scale`` becomes ``(1 + scale) / k_factor``.
        The shift half of every mapper is divided by ``k``.

        Takes effect only while ``self.factor`` is still 1.
        """
        if self.factor != 1:
            return
        self.factor = k
        k_factor = k / ext_k
        a_weight_factor = 1 / k_factor
        b_weight_factor = 1 / k
        # Spread the constant needed to turn (1 + a) into (1 + a) / k_factor
        # evenly across all mappers, since their outputs are summed in forward().
        a_bias_offset = -((k_factor - 1) / k_factor) / (len(self.conds) + 1)

        mappers = [self.mapper] + [getattr(self, f"mapper_{cname}") for cname in self.conds]
        for module in mappers:
            # First half of the rows produces the scale (a), second half the shift (b).
            a_bias, b_bias = module.bias.data.chunk(2, dim=0)
            a_weight, b_weight = module.weight.data.chunk(2, dim=0)
            module.weight.data.copy_(torch.cat([a_weight * a_weight_factor, b_weight * b_weight_factor]))
            module.bias.data.copy_(torch.cat([a_bias * a_weight_factor + a_bias_offset, b_bias * b_weight_factor]))

    def forward(self, x, t):
        """Modulate ``x`` with the embeddings packed in ``t``.

        Args:
            x: feature map with channels on axis 1.
            t: embeddings concatenated on axis 1; split into ``len(conds) + 1``
                equal chunks, the first for ``mapper`` and the rest for the
                ``mapper_<name>`` layers in ``conds`` order.
        """
        chunks = t.chunk(len(self.conds) + 1, dim=1)
        a, b = self.mapper(chunks[0])[:, :, None, None].chunk(2, dim=1)
        for cname, chunk in zip(self.conds, chunks[1:]):
            ac, bc = getattr(self, f"mapper_{cname}")(chunk)[:, :, None, None].chunk(2, dim=1)
            a, b = a + ac, b + bc
        return x * (1 + a) + b
389
+
390
+
391
class UpDownBlock2d(nn.Module):
    """Bilinear 2x up- or down-sampling combined with a 1x1 channel-mapping convolution.

    In "up" mode the resampling runs before the convolution; in "down" mode it
    runs after it.
    """

    def __init__(self, c_in, c_out, mode, enabled=True):
        """
        Args:
            c_in: input channels of the 1x1 convolution.
            c_out: output channels of the 1x1 convolution.
            mode: ``"up"`` or ``"down"``.
            enabled: when false, the resampling step is replaced by ``nn.Identity``.
        """
        super().__init__()
        assert mode in ["up", "down"]
        if enabled:
            scale = 2 if mode == "up" else 0.5
            interpolation = nn.Upsample(scale_factor=scale, mode="bilinear", align_corners=True)
        else:
            interpolation = nn.Identity()
        mapping = nn.Conv2d(c_in, c_out, kernel_size=1)
        ordered = [interpolation, mapping] if mode == "up" else [mapping, interpolation]
        self.blocks = nn.ModuleList(ordered)

        self.mode = mode

        self.gradient_checkpointing = False

    def set_gradient_checkpointing(self, value):
        """Enable or disable gradient checkpointing for this block."""
        self.gradient_checkpointing = value

    def forward_body(self, x):
        """Apply both sub-blocks in order and return the result in the input dtype."""
        org_dtype = x.dtype
        # Position of the resampling step inside self.blocks.
        resample_index = 0 if self.mode == "up" else 1
        for i, block in enumerate(self.blocks):
            # In the official implementation, it always calculates in float, but for the sake of
            # saving memory, it converts to float only for bfloat16 + Upsample
            if i == resample_index and x.dtype == torch.bfloat16:
                x = x.float()
            x = block(x).to(org_dtype)
        return x

    def forward(self, x):
        """Run the block, using gradient checkpointing when training with it enabled."""
        if not (self.training and self.gradient_checkpointing):
            return self.forward_body(x)
        return torch.utils.checkpoint.checkpoint(self.forward_body, x)
434
+
435
+
436
class StageAResBlock(nn.Module):
    """Residual block of Stage A: a gated depthwise-conv branch followed by a gated MLP branch.

    Six learnable scalars (``gammas``) provide, for each branch, a scale and a
    shift applied after normalisation and a gate on the branch output. They
    start at zero, so a freshly built block is the identity.
    """

    def __init__(self, c, c_hidden):
        """
        Args:
            c: number of input/output channels.
            c_hidden: hidden width of the channelwise MLP.
        """
        super().__init__()
        # depthwise/attention
        self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
        self.depthwise = nn.Sequential(nn.ReplicationPad2d(1), nn.Conv2d(c, c, kernel_size=3, groups=c))

        # channelwise
        self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
        self.channelwise = nn.Sequential(
            nn.Linear(c, c_hidden),
            nn.GELU(),
            nn.Linear(c_hidden, c),
        )

        self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)

        # Init weights: Xavier-uniform weights and zero biases for every linear/conv layer.
        def _basic_init(module):
            if not isinstance(module, (nn.Linear, nn.Conv2d)):
                return
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

    def _norm(self, x, norm):
        """Apply the channels-last layer ``norm`` to a channels-first tensor."""
        return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def forward(self, x):
        """Apply both gated residual branches to ``x`` (channels on axis 1)."""
        scale_dw, shift_dw, gate_dw, scale_cw, shift_cw, gate_cw = self.gammas

        modulated = self._norm(x, self.norm1) * (1 + scale_dw) + shift_dw
        x = x + self.depthwise(modulated) * gate_dw

        modulated = self._norm(x, self.norm2) * (1 + scale_cw) + shift_cw
        x = x + self.channelwise(modulated.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * gate_cw

        return x
475
+
476
+
477
class StageA(nn.Module):
    """Stage A autoencoder with an optional vector-quantised bottleneck.

    The encoder pixel-unshuffles the image by 2 and downsamples by a further
    factor of 2 per extra level; the decoder mirrors it with transposed
    convolutions and a final pixel shuffle. Latents are divided by
    ``scale_factor`` on the way out of ``encode`` and multiplied by it again at
    the start of ``decode``.
    """

    def __init__(
        self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192, scale_factor=0.43
    ):  # 0.3764
        """
        Args:
            levels: number of resolution levels in the encoder and decoder.
            bottleneck_blocks: number of residual blocks at the first decoder level.
            c_hidden: channel width at the deepest level; halved per shallower level.
            c_latent: channels of the latent.
            codebook_size: number of entries of the VQ codebook.
            scale_factor: scaling applied to latents (see class docstring).
        """
        super().__init__()
        self.c_latent = c_latent
        self.scale_factor = scale_factor
        # Channel widths from the shallowest level to the deepest.
        c_levels = [c_hidden // (2**i) for i in reversed(range(levels))]

        # Encoder blocks
        self.in_block = nn.Sequential(nn.PixelUnshuffle(2), nn.Conv2d(3 * 4, c_levels[0], kernel_size=1))
        down_blocks = []
        for i in range(levels):
            if i > 0:
                down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
            down_blocks.append(StageAResBlock(c_levels[i], c_levels[i] * 4))
        down_blocks.append(
            nn.Sequential(
                nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
                nn.BatchNorm2d(c_latent),  # then normalize them to have mean 0 and std 1
            )
        )
        self.down_blocks = nn.Sequential(*down_blocks)

        self.codebook_size = codebook_size
        self.vquantizer = VectorQuantize(c_latent, k=codebook_size)

        # Decoder blocks
        up_blocks = [nn.Sequential(nn.Conv2d(c_latent, c_levels[-1], kernel_size=1))]
        for i in range(levels):
            c_level = c_levels[levels - 1 - i]
            for _ in range(bottleneck_blocks if i == 0 else 1):
                up_blocks.append(StageAResBlock(c_level, c_level * 4))
            if i < levels - 1:
                up_blocks.append(
                    nn.ConvTranspose2d(c_level, c_levels[levels - 2 - i], kernel_size=4, stride=2, padding=1)
                )
        self.up_blocks = nn.Sequential(*up_blocks)
        self.out_block = nn.Sequential(
            nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
            nn.PixelShuffle(2),
        )

    def encode(self, x, quantize=False):
        """Encode images into scaled latents.

        Returns:
            With ``quantize=True``: ``(quantized, latent, indices, loss)`` where
            both latents are divided by ``scale_factor`` and
            ``loss = vq_loss + 0.25 * commit_loss``.
            With ``quantize=False``: ``(latent, None, None, None)``.
        """
        x = self.in_block(x)
        x = self.down_blocks(x)
        if not quantize:
            return x / self.scale_factor, None, None, None
        # Call the module (not .forward) so that registered hooks are honoured.
        qe, (vq_loss, commit_loss), indices = self.vquantizer(x, dim=1)
        return qe / self.scale_factor, x / self.scale_factor, indices, vq_loss + commit_loss * 0.25

    def decode(self, x):
        """Decode scaled latents back to image space."""
        x = x * self.scale_factor
        x = self.up_blocks(x)
        return self.out_block(x)

    def forward(self, x, quantize=False):
        """Encode then decode ``x``; return ``(reconstruction, vq_loss)``.

        ``vq_loss`` is ``None`` unless ``quantize`` is true.
        """
        latent, _, _, vq_loss = self.encode(x, quantize)
        return self.decode(latent), vq_loss
539
+
540
+
541
+ r"""
542
+
543
+ https://github.com/Stability-AI/StableCascade/blob/master/configs/inference/stage_b_3b.yaml
544
+
545
+ # GLOBAL STUFF
546
+ model_version: 3B
547
+ dtype: bfloat16
548
+
549
+ # For demonstration purposes in reconstruct_images.ipynb
550
+ webdataset_path: file:inference/imagenet_1024.tar
551
+ batch_size: 4
552
+ image_size: 1024
553
+ grad_accum_steps: 1
554
+
555
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
556
+ stage_a_checkpoint_path: models/stage_a.safetensors
557
+ generator_checkpoint_path: models/stage_b_bf16.safetensors
558
+ """
559
+
560
+
561
+ class StageB(nn.Module):
562
+ def __init__(
563
+ self,
564
+ c_in=4,
565
+ c_out=4,
566
+ c_r=64,
567
+ patch_size=2,
568
+ c_cond=1280,
569
+ c_hidden=[320, 640, 1280, 1280],
570
+ nhead=[-1, -1, 20, 20],
571
+ blocks=[[2, 6, 28, 6], [6, 28, 6, 2]],
572
+ block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]],
573
+ level_config=["CT", "CT", "CTA", "CTA"],
574
+ c_clip=1280,
575
+ c_clip_seq=4,
576
+ c_effnet=16,
577
+ c_pixels=3,
578
+ kernel_size=3,
579
+ dropout=[0, 0, 0.1, 0.1],
580
+ self_attn=True,
581
+ t_conds=["sca"],
582
+ ):
583
+ super().__init__()
584
+ self.c_r = c_r
585
+ self.t_conds = t_conds
586
+ self.c_clip_seq = c_clip_seq
587
+ if not isinstance(dropout, list):
588
+ dropout = [dropout] * len(c_hidden)
589
+ if not isinstance(self_attn, list):
590
+ self_attn = [self_attn] * len(c_hidden)
591
+
592
+ # CONDITIONING
593
+ self.effnet_mapper = nn.Sequential(
594
+ nn.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1),
595
+ nn.GELU(),
596
+ nn.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1),
597
+ LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
598
+ )
599
+ self.pixels_mapper = nn.Sequential(
600
+ nn.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1),
601
+ nn.GELU(),
602
+ nn.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1),
603
+ LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
604
+ )
605
+ self.clip_mapper = nn.Linear(c_clip, c_cond * c_clip_seq)
606
+ self.clip_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6)
607
+
608
+ self.embedding = nn.Sequential(
609
+ nn.PixelUnshuffle(patch_size),
610
+ nn.Conv2d(c_in * (patch_size**2), c_hidden[0], kernel_size=1),
611
+ LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
612
+ )
613
+
614
+ def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
615
+ if block_type == "C":
616
+ return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout)
617
+ elif block_type == "A":
618
+ return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout)
619
+ elif block_type == "F":
620
+ return FeedForwardBlock(c_hidden, dropout=dropout)
621
+ elif block_type == "T":
622
+ return TimestepBlock(c_hidden, c_r, conds=t_conds)
623
+ else:
624
+ raise Exception(f"Block type {block_type} not supported")
625
+
626
+ # BLOCKS
627
+ # -- down blocks
628
+ self.down_blocks = nn.ModuleList()
629
+ self.down_downscalers = nn.ModuleList()
630
+ self.down_repeat_mappers = nn.ModuleList()
631
+ for i in range(len(c_hidden)):
632
+ if i > 0:
633
+ self.down_downscalers.append(
634
+ nn.Sequential(
635
+ LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
636
+ nn.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2),
637
+ )
638
+ )
639
+ else:
640
+ self.down_downscalers.append(nn.Identity())
641
+ down_block = nn.ModuleList()
642
+ for _ in range(blocks[0][i]):
643
+ for block_type in level_config[i]:
644
+ block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
645
+ down_block.append(block)
646
+ self.down_blocks.append(down_block)
647
+ if block_repeat is not None:
648
+ block_repeat_mappers = nn.ModuleList()
649
+ for _ in range(block_repeat[0][i] - 1):
650
+ block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1))
651
+ self.down_repeat_mappers.append(block_repeat_mappers)
652
+
653
+ # -- up blocks
654
+ self.up_blocks = nn.ModuleList()
655
+ self.up_upscalers = nn.ModuleList()
656
+ self.up_repeat_mappers = nn.ModuleList()
657
+ for i in reversed(range(len(c_hidden))):
658
+ if i > 0:
659
+ self.up_upscalers.append(
660
+ nn.Sequential(
661
+ LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6),
662
+ nn.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2),
663
+ )
664
+ )
665
+ else:
666
+ self.up_upscalers.append(nn.Identity())
667
+ up_block = nn.ModuleList()
668
+ for j in range(blocks[1][::-1][i]):
669
+ for k, block_type in enumerate(level_config[i]):
670
+ c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
671
+ block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], self_attn=self_attn[i])
672
+ up_block.append(block)
673
+ self.up_blocks.append(up_block)
674
+ if block_repeat is not None:
675
+ block_repeat_mappers = nn.ModuleList()
676
+ for _ in range(block_repeat[1][::-1][i] - 1):
677
+ block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1))
678
+ self.up_repeat_mappers.append(block_repeat_mappers)
679
+
680
+ # OUTPUT
681
+ self.clf = nn.Sequential(
682
+ LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
683
+ nn.Conv2d(c_hidden[0], c_out * (patch_size**2), kernel_size=1),
684
+ nn.PixelShuffle(patch_size),
685
+ )
686
+
687
+ # --- WEIGHT INIT ---
688
+ self.apply(self._init_weights) # General init
689
+ nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings
690
+ nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings
691
+ nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings
692
+ nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings
693
+ nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings
694
+ torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
695
+ nn.init.constant_(self.clf[1].weight, 0) # outputs
696
+
697
+ # blocks
698
+ for level_block in self.down_blocks + self.up_blocks:
699
+ for block in level_block:
700
+ if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
701
+ block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
702
+ elif isinstance(block, TimestepBlock):
703
+ for layer in block.modules():
704
+ if isinstance(layer, nn.Linear):
705
+ nn.init.constant_(layer.weight, 0)
706
+
707
+ def _init_weights(self, m):
708
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
709
+ torch.nn.init.xavier_uniform_(m.weight)
710
+ if m.bias is not None:
711
+ nn.init.constant_(m.bias, 0)
712
+
713
+ def gen_r_embedding(self, r, max_positions=10000):
714
+ r = r * max_positions
715
+ half_dim = self.c_r // 2
716
+ emb = math.log(max_positions) / (half_dim - 1)
717
+ emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
718
+ emb = r[:, None] * emb[None, :]
719
+ emb = torch.cat([emb.sin(), emb.cos()], dim=1)
720
+ if self.c_r % 2 == 1: # zero pad
721
+ emb = nn.functional.pad(emb, (0, 1), mode="constant")
722
+ return emb
723
+
724
+ def gen_c_embeddings(self, clip):
725
+ if len(clip.shape) == 2:
726
+ clip = clip.unsqueeze(1)
727
+ clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1)
728
+ clip = self.clip_norm(clip)
729
+ return clip
730
+
731
+ def _down_encode(self, x, r_embed, clip):
732
+ level_outputs = []
733
+ block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
734
+ for down_block, downscaler, repmap in block_group:
735
+ x = downscaler(x)
736
+ for i in range(len(repmap) + 1):
737
+ for block in down_block:
738
+ if isinstance(block, ResBlock) or (
739
+ hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, ResBlock)
740
+ ):
741
+ x = block(x)
742
+ elif isinstance(block, AttnBlock) or (
743
+ hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, AttnBlock)
744
+ ):
745
+ x = block(x, clip)
746
+ elif isinstance(block, TimestepBlock) or (
747
+ hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, TimestepBlock)
748
+ ):
749
+ x = block(x, r_embed)
750
+ else:
751
+ x = block(x)
752
+ if i < len(repmap):
753
+ x = repmap[i](x)
754
+ level_outputs.insert(0, x)
755
+ return level_outputs
756
+
757
+ def _up_decode(self, level_outputs, r_embed, clip):
758
+ x = level_outputs[0]
759
+ block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
760
+ for i, (up_block, upscaler, repmap) in enumerate(block_group):
761
+ for j in range(len(repmap) + 1):
762
+ for k, block in enumerate(up_block):
763
+ if isinstance(block, ResBlock) or (
764
+ hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, ResBlock)
765
+ ):
766
+ skip = level_outputs[i] if k == 0 and i > 0 else None
767
+ if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
768
+ x = torch.nn.functional.interpolate(x.float(), skip.shape[-2:], mode="bilinear", align_corners=True)
769
+ x = block(x, skip)
770
+ elif isinstance(block, AttnBlock) or (
771
+ hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, AttnBlock)
772
+ ):
773
+ x = block(x, clip)
774
+ elif isinstance(block, TimestepBlock) or (
775
+ hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, TimestepBlock)
776
+ ):
777
+ x = block(x, r_embed)
778
+ else:
779
+ x = block(x)
780
+ if j < len(repmap):
781
+ x = repmap[j](x)
782
+ x = upscaler(x)
783
+ return x
784
+
785
def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
    """Embed the inputs, add the effnet/pixel conditioning, and run the down and up paths.

    Args:
        x: input tensor fed to ``self.embedding``.
        r: tensor passed to ``self.gen_r_embedding``.
        effnet: tensor resized to the embedded ``x`` and passed through ``self.effnet_mapper``.
        clip: input of ``self.gen_c_embeddings``.
        pixels: optional tensor for ``self.pixels_mapper``; when omitted a zero
            tensor of shape ``(x.size(0), 3, 8, 8)`` is used.
        **kwargs: may hold one tensor per name in ``self.t_conds``; a missing
            name is replaced by ``torch.zeros_like(r)``.

    Returns:
        ``self.clf`` applied to the output of the up path.
    """
    resize = nn.functional.interpolate
    if pixels is None:
        pixels = x.new_zeros(x.size(0), 3, 8, 8)

    # Process the conditioning embeddings
    r_embed = self.gen_r_embedding(r)
    for cond_name in self.t_conds:
        cond_value = kwargs.get(cond_name, torch.zeros_like(r))
        cond_embed = self.gen_r_embedding(cond_value)
        r_embed = torch.cat([r_embed, cond_embed], dim=1)
    clip = self.gen_c_embeddings(clip)

    # Model Blocks
    x = self.embedding(x)
    effnet_resized = resize(effnet.float(), size=x.shape[-2:], mode="bilinear", align_corners=True)
    x = x + self.effnet_mapper(effnet_resized)
    pixels_mapped = self.pixels_mapper(pixels).float()
    x = x + resize(pixels_mapped, size=x.shape[-2:], mode="bilinear", align_corners=True)

    level_outputs = self._down_encode(x, r_embed, clip)
    decoded = self._up_decode(level_outputs, r_embed, clip)
    return self.clf(decoded)
807
+
808
def update_weights_ema(self, src_model, beta=0.999):
    """Blend this model's parameters and buffers towards ``src_model``.

    Every tensor becomes ``own * beta + src * (1 - beta)``; the source tensor
    is copied to the device of the tensor it updates. Parameters are processed
    first, then buffers, each paired positionally with ``src_model``'s.
    """
    tensor_groups = (
        zip(self.parameters(), src_model.parameters()),
        zip(self.buffers(), src_model.buffers()),
    )
    for group in tensor_groups:
        for own, src in group:
            src_on_device = src.data.clone().to(own.device)
            own.data = own.data * beta + src_on_device * (1 - beta)
813
+
814
+
815
+ r"""
816
+
817
+ https://github.com/Stability-AI/StableCascade/blob/master/configs/inference/stage_c_3b.yaml
818
+
819
+ # GLOBAL STUFF
820
+ model_version: 3.6B
821
+ dtype: bfloat16
822
+
823
+ effnet_checkpoint_path: models/effnet_encoder.safetensors
824
+ previewer_checkpoint_path: models/previewer.safetensors
825
+ generator_checkpoint_path: models/stage_c_bf16.safetensors
826
+ """
827
+
828
+
829
class StageC(nn.Module):
    """Conditioned network made of down and up levels of C/A/F/T blocks.

    Each level is a repeated sequence of blocks chosen by ``level_config``
    ("C" = ResBlock, "A" = AttnBlock, "F" = FeedForwardBlock, "T" = TimestepBlock).
    Conditioning comes from three CLIP inputs (text sequence, pooled text,
    image) which are projected to ``c_cond`` channels, and from sinusoidal
    embeddings of ``r`` plus one extra embedding per name in ``t_conds``.
    """

    def __init__(
        self,
        c_in=16,
        c_out=16,
        c_r=64,
        patch_size=1,
        c_cond=2048,
        c_hidden=[2048, 2048],
        nhead=[32, 32],
        blocks=[[8, 24], [24, 8]],
        block_repeat=[[1, 1], [1, 1]],
        level_config=["CTA", "CTA"],
        c_clip_text=1280,
        c_clip_text_pooled=1280,
        c_clip_img=768,
        c_clip_seq=4,
        kernel_size=3,
        dropout=[0.1, 0.1],
        self_attn=True,
        t_conds=["sca", "crp"],
        switch_level=[False],
    ):
        super().__init__()
        self.c_r = c_r
        self.t_conds = t_conds
        self.c_clip_seq = c_clip_seq
        # Scalars are broadcast to one value per level.
        if not isinstance(dropout, list):
            dropout = [dropout] * len(c_hidden)
        if not isinstance(self_attn, list):
            self_attn = [self_attn] * len(c_hidden)

        # CONDITIONING
        self.clip_txt_mapper = nn.Linear(c_clip_text, c_cond)
        self.clip_txt_pooled_mapper = nn.Linear(c_clip_text_pooled, c_cond * c_clip_seq)
        self.clip_img_mapper = nn.Linear(c_clip_img, c_cond * c_clip_seq)
        self.clip_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6)

        self.embedding = nn.Sequential(
            nn.PixelUnshuffle(patch_size),
            nn.Conv2d(c_in * (patch_size**2), c_hidden[0], kernel_size=1),
            LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
        )

        def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
            # Map a one-letter level_config code to the block it stands for.
            if block_type == "C":
                return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout)
            elif block_type == "A":
                return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout)
            elif block_type == "F":
                return FeedForwardBlock(c_hidden, dropout=dropout)
            elif block_type == "T":
                return TimestepBlock(c_hidden, c_r, conds=t_conds)
            else:
                raise Exception(f"Block type {block_type} not supported")

        # BLOCKS
        # -- down blocks
        self.down_blocks = nn.ModuleList()
        self.down_downscalers = nn.ModuleList()
        self.down_repeat_mappers = nn.ModuleList()
        for i in range(len(c_hidden)):
            if i > 0:
                self.down_downscalers.append(
                    nn.Sequential(
                        LayerNorm2d(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
                        UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode="down", enabled=switch_level[i - 1]),
                    )
                )
            else:
                # The first level works directly on the embedded input.
                self.down_downscalers.append(nn.Identity())
            down_block = nn.ModuleList()
            for _ in range(blocks[0][i]):
                for block_type in level_config[i]:
                    block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
                    down_block.append(block)
            self.down_blocks.append(down_block)
            if block_repeat is not None:
                block_repeat_mappers = nn.ModuleList()
                for _ in range(block_repeat[0][i] - 1):
                    block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1))
                self.down_repeat_mappers.append(block_repeat_mappers)

        # -- up blocks
        self.up_blocks = nn.ModuleList()
        self.up_upscalers = nn.ModuleList()
        self.up_repeat_mappers = nn.ModuleList()
        for i in reversed(range(len(c_hidden))):
            if i > 0:
                self.up_upscalers.append(
                    nn.Sequential(
                        LayerNorm2d(c_hidden[i], elementwise_affine=False, eps=1e-6),
                        UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode="up", enabled=switch_level[i - 1]),
                    )
                )
            else:
                self.up_upscalers.append(nn.Identity())
            up_block = nn.ModuleList()
            for j in range(blocks[1][::-1][i]):
                for k, block_type in enumerate(level_config[i]):
                    # Only the very first block of every level except the deepest takes a skip input.
                    c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
                    block = get_block(
                        block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i], self_attn=self_attn[i]
                    )
                    up_block.append(block)
            self.up_blocks.append(up_block)
            if block_repeat is not None:
                block_repeat_mappers = nn.ModuleList()
                for _ in range(block_repeat[1][::-1][i] - 1):
                    block_repeat_mappers.append(nn.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1))
                self.up_repeat_mappers.append(block_repeat_mappers)

        # OUTPUT
        self.clf = nn.Sequential(
            LayerNorm2d(c_hidden[0], elementwise_affine=False, eps=1e-6),
            nn.Conv2d(c_hidden[0], c_out * (patch_size**2), kernel_size=1),
            nn.PixelShuffle(patch_size),
        )

        # --- WEIGHT INIT ---
        self.apply(self._init_weights)  # General init
        nn.init.normal_(self.clip_txt_mapper.weight, std=0.02)  # conditionings
        nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02)  # conditionings
        nn.init.normal_(self.clip_img_mapper.weight, std=0.02)  # conditionings
        torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02)  # inputs
        nn.init.constant_(self.clf[1].weight, 0)  # outputs

        # blocks
        for level_block in self.down_blocks + self.up_blocks:
            for block in level_block:
                if isinstance(block, (ResBlock, FeedForwardBlock)):
                    # Scale the last channelwise layer by 1/sqrt(number of down-path block repeats).
                    block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
                elif isinstance(block, TimestepBlock):
                    for layer in block.modules():
                        if isinstance(layer, nn.Linear):
                            nn.init.constant_(layer.weight, 0)

    def _init_weights(self, m):
        """Xavier-initialise Conv2d/Linear weights and zero their biases."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def set_gradient_checkpointing(self, value):
        """Forward ``value`` to every block that exposes ``set_gradient_checkpointing``."""
        for block in self.down_blocks + self.up_blocks:
            for layer in block:
                if hasattr(layer, "set_gradient_checkpointing"):
                    layer.set_gradient_checkpointing(value)

    def gen_r_embedding(self, r, max_positions=10000):
        """Return a sinusoidal embedding of ``r`` with ``self.c_r`` columns.

        ``r`` is scaled by ``max_positions``; the first half of the columns are
        sines and the second half cosines. An odd ``c_r`` is zero-padded by one
        column.
        """
        r = r * max_positions
        half_dim = self.c_r // 2
        emb = math.log(max_positions) / (half_dim - 1)
        emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
        emb = r[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.c_r % 2 == 1:  # zero pad
            emb = nn.functional.pad(emb, (0, 1), mode="constant")
        return emb

    def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img):
        """Project the three CLIP inputs and concatenate them along dim 1.

        Args:
            clip_txt: text embeddings, fed to ``clip_txt_mapper`` as-is.
            clip_txt_pooled: pooled text embeddings; a 2-D tensor gets a
                singleton dim inserted at position 1.
            clip_img: image embeddings; a 2-D tensor gets a singleton dim
                inserted at position 1.

        Returns:
            Layer-normalised concatenation ``[clip_txt, pooled, img]`` where the
            pooled and image parts each contribute ``size(1) * c_clip_seq`` tokens.
        """
        clip_txt = self.clip_txt_mapper(clip_txt)
        if len(clip_txt_pooled.shape) == 2:
            # Fix: the expanded tensor used to be stored under a different name
            # and was then ignored, so 2-D pooled inputs broke the view() below.
            clip_txt_pooled = clip_txt_pooled.unsqueeze(1)
        if len(clip_img.shape) == 2:
            clip_img = clip_img.unsqueeze(1)
        clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(
            clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1
        )
        clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1)
        clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
        clip = self.clip_norm(clip)
        return clip

    @staticmethod
    def _is_block(block, block_cls):
        """True if ``block``, or the module it exposes as ``_fsdp_wrapped_module``, is a ``block_cls``."""
        if isinstance(block, block_cls):
            return True
        return hasattr(block, "_fsdp_wrapped_module") and isinstance(block._fsdp_wrapped_module, block_cls)

    @staticmethod
    def _add_cnet(x, cnet):
        """Add the next value produced by ``cnet()`` (resized to ``x``) if there is one."""
        if cnet is not None:
            next_cnet = cnet()
            if next_cnet is not None:
                x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode="bilinear", align_corners=True)
        return x

    def _down_encode(self, x, r_embed, clip, cnet=None):
        """Run the down path and return the level outputs, deepest level first."""
        level_outputs = []
        block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
        for down_block, downscaler, repmap in block_group:
            x = downscaler(x)
            # The block list runs once, plus once more per repeat mapper.
            for i in range(len(repmap) + 1):
                for block in down_block:
                    if self._is_block(block, ResBlock):
                        x = self._add_cnet(x, cnet)
                        x = block(x)
                    elif self._is_block(block, AttnBlock):
                        x = block(x, clip)
                    elif self._is_block(block, TimestepBlock):
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if i < len(repmap):
                    x = repmap[i](x)
            level_outputs.insert(0, x)
        return level_outputs

    def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
        """Run the up path, starting from ``level_outputs[0]``, and return the final feature map."""
        x = level_outputs[0]
        block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
        for i, (up_block, upscaler, repmap) in enumerate(block_group):
            for j in range(len(repmap) + 1):
                for k, block in enumerate(up_block):
                    if self._is_block(block, ResBlock):
                        # The first block of every level after the first receives the skip.
                        skip = level_outputs[i] if k == 0 and i > 0 else None
                        if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
                            x = torch.nn.functional.interpolate(
                                x.float(), skip.shape[-2:], mode="bilinear", align_corners=True
                            )
                        x = self._add_cnet(x, cnet)
                        x = block(x, skip)
                    elif self._is_block(block, AttnBlock):
                        x = block(x, clip)
                    elif self._is_block(block, TimestepBlock):
                        x = block(x, r_embed)
                    else:
                        x = block(x)
                if j < len(repmap):
                    x = repmap[j](x)
            x = upscaler(x)
        return x

    def forward(self, x, r, clip_text, clip_text_pooled, clip_img, cnet=None, **kwargs):
        """Compute the model output for ``x``.

        Args:
            x: input tensor fed to ``self.embedding``.
            r: tensor embedded with ``gen_r_embedding``.
            clip_text, clip_text_pooled, clip_img: inputs of ``gen_c_embeddings``.
            cnet: optional callable handed to the down/up paths; each call is
                expected to return a tensor to add before a ResBlock, or None.
            **kwargs: may hold one tensor per name in ``self.t_conds``; a
                missing name is replaced by ``torch.zeros_like(r)``.

        Returns:
            ``self.clf`` applied to the output of the up path.
        """
        # Process the conditioning embeddings
        r_embed = self.gen_r_embedding(r)
        for c in self.t_conds:
            t_cond = kwargs.get(c, torch.zeros_like(r))
            r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond)], dim=1)
        clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img)

        # Model Blocks
        x = self.embedding(x)
        # ControlNet is not supported yet (cnet is passed through unwrapped).
        level_outputs = self._down_encode(x, r_embed, clip, cnet)
        x = self._up_decode(level_outputs, r_embed, clip, cnet)
        return self.clf(x)

    def update_weights_ema(self, src_model, beta=0.999):
        """Set every parameter and buffer to ``own * beta + src * (1 - beta)``."""
        for self_params, src_params in zip(self.parameters(), src_model.parameters()):
            self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
        for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
            self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)

    @property
    def device(self):
        """Device of the first parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Dtype of the first parameter."""
        return next(self.parameters()).dtype
1139
+
1140
+
1141
+ # Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192
1142
class Previewer(nn.Module):
    """Small convolutional decoder that upsamples its input 8x spatially.

    The stack is stored as ``self.blocks`` (an ``nn.Sequential`` of 25 layers):
    two Conv/GELU/BatchNorm stages at ``c_hidden`` channels, three
    ConvTranspose(stride 2) + Conv stages that each double the resolution, and
    a final 1x1 convolution to ``c_out`` channels.
    """

    def __init__(self, c_in=16, c_hidden=512, c_out=3):
        super().__init__()
        half = c_hidden // 2
        quarter = c_hidden // 4

        def act_norm(channels):
            # Every convolution except the last is followed by GELU + BatchNorm.
            return [nn.GELU(), nn.BatchNorm2d(channels)]

        layers = [nn.Conv2d(c_in, c_hidden, kernel_size=1)]  # c_in channels to c_hidden channels
        layers += act_norm(c_hidden)
        layers.append(nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1))
        layers += act_norm(c_hidden)

        # Three x2 upsampling stages (e.g. 16 -> 32 -> 64 -> 128).
        for src, dst in ((c_hidden, half), (half, quarter), (quarter, quarter)):
            layers.append(nn.ConvTranspose2d(src, dst, kernel_size=2, stride=2))
            layers += act_norm(dst)
            layers.append(nn.Conv2d(dst, dst, kernel_size=3, padding=1))
            layers += act_norm(dst)

        layers.append(nn.Conv2d(quarter, c_out, kernel_size=1))
        self.blocks = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        return self.blocks(x)

    @property
    def device(self):
        """Device of the first parameter."""
        first = next(self.parameters())
        return first.device

    @property
    def dtype(self):
        """Dtype of the first parameter."""
        first = next(self.parameters())
        return first.dtype
1183
+
1184
+
1185
def get_clip_conditions(captions: Optional[List[str]], input_ids, tokenizer, text_model):
    """Encode text with ``text_model`` and return (last hidden state, pooled embeddings).

    Marked as deprecated in the original source.

    Args:
        captions: raw strings to tokenize; when None, ``input_ids`` is used instead.
        input_ids: pre-tokenized input, only used when ``captions`` is None.
        tokenizer: callable tokenizer exposing ``model_max_length``.
        text_model: text encoder; its output must expose ``hidden_states`` and
            ``text_embeds``.

    Returns:
        Tuple of ``hidden_states[-1]`` and ``text_embeds`` with a singleton dim
        inserted at position 1.
    """
    # Original signature idea:
    # self, batch: dict, tokenizer, text_model, is_eval=False, is_unconditional=False, eval_image_embeds=False, return_fields=None
    # Handling is_eval here is awkward, so it is done elsewhere.
    # Handling is_unconditional here is awkward as well, so it is done elsewhere.
    # clip_image is not supported for now.
    if captions is None:
        encoder_out = text_model(input_ids, output_hidden_states=True)
    else:
        tokens = tokenizer(
            captions,
            truncation=True,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        )
        tokens = tokens.to(text_model.device)
        encoder_out = text_model(**tokens, output_hidden_states=True)

    last_hidden = encoder_out.hidden_states[-1]
    pooled = encoder_out.text_embeds.unsqueeze(1)

    # A dict form {"clip_text": ..., "clip_text_pooled": ...} was considered but a tuple is returned.
    return last_hidden, pooled
1205
+
1206
+
1207
+ # region gdf
1208
+
1209
+
1210
class SimpleSampler:
    """Base class for samplers: counts calls and delegates the update to ``step``.

    Subclasses implement ``step``; calling the sampler increments
    ``current_step`` (which starts at -1) and then forwards all arguments to
    ``step``.
    """

    def __init__(self, gdf):
        self.gdf = gdf
        self.current_step = -1

    def __call__(self, *args, **kwargs):
        """Advance the step counter and run one ``step`` with the given arguments."""
        self.current_step += 1
        return self.step(*args, **kwargs)

    def init_x(self, shape):
        """Return standard-normal noise of the given shape."""
        return torch.randn(*shape)

    def step(self, x, x0, epsilon, logSNR, logSNR_prev):
        """Compute the next sample; must be provided by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        # The message used to name a non-existent 'apply' method.
        raise NotImplementedError("You should override the 'step' function.")
1224
+
1225
+
1226
class DDIMSampler(SimpleSampler):
    """Sampler whose step mixes x0, epsilon and (for eta > 0) fresh noise."""

    def step(self, x, x0, epsilon, logSNR, logSNR_prev, eta=0):
        """Return ``a_prev * x0 + sqrt(b_prev^2 - sigma^2) * epsilon + sigma * noise``.

        ``(a, b)`` and ``(a_prev, b_prev)`` come from ``self.gdf.input_scaler``
        at ``logSNR`` and ``logSNR_prev``. ``sigma`` is 0 unless ``eta > 0``.
        ``x`` itself is not used.
        """
        # Shape that lets per-sample scalars broadcast against x0.
        per_sample = (-1,) + (1,) * (len(x0.shape) - 1)

        a, b = self.gdf.input_scaler(logSNR)
        if len(a.shape) == 1:
            a = a.view(per_sample)
            b = b.view(per_sample)

        a_prev, b_prev = self.gdf.input_scaler(logSNR_prev)
        if len(a_prev.shape) == 1:
            a_prev = a_prev.view(per_sample)
            b_prev = b_prev.view(per_sample)

        if eta > 0:
            sigma_tau = eta * (b_prev**2 / b**2).sqrt() * (1 - a**2 / a_prev**2).sqrt()
        else:
            sigma_tau = 0

        # Noise is drawn even when sigma_tau is 0.
        fresh_noise = torch.randn_like(x0)
        # x = a_prev * x0 + (1 - a_prev**2 - sigma_tau ** 2).sqrt() * epsilon + sigma_tau * torch.randn_like(x0)
        return a_prev * x0 + (b_prev**2 - sigma_tau**2).sqrt() * epsilon + sigma_tau * fresh_noise
1240
+
1241
+
1242
class DDPMSampler(DDIMSampler):
    """DDIM step with ``eta`` defaulting to 1 instead of 0."""

    def step(self, x, x0, epsilon, logSNR, logSNR_prev, eta=1):
        """Delegate to ``DDIMSampler.step`` with the given ``eta``."""
        return super().step(x, x0, epsilon, logSNR, logSNR_prev, eta=eta)
1245
+
1246
+
1247
class LCMSampler(SimpleSampler):
    """Sampler that re-noises the predicted x0 to the previous noise level."""

    def step(self, x, x0, epsilon, logSNR, logSNR_prev):
        """Return ``x0 * a_prev + noise * b_prev`` with fresh noise shaped like ``epsilon``.

        ``(a_prev, b_prev)`` come from ``self.gdf.input_scaler(logSNR_prev)``;
        ``x`` and ``logSNR`` are not used.
        """
        a_prev, b_prev = self.gdf.input_scaler(logSNR_prev)
        if len(a_prev.shape) == 1:
            # Reshape per-sample scalars so they broadcast against x0.
            per_sample = (-1,) + (1,) * (len(x0.shape) - 1)
            a_prev = a_prev.view(per_sample)
            b_prev = b_prev.view(per_sample)
        return x0 * a_prev + torch.randn_like(epsilon) * b_prev
1253
+
1254
+
1255
class GDF:
    """Bundle of schedule, input scaler, target, noise conditioning and loss weight.

    Provides ``diffuse`` (add noise to clean data), ``undiffuse`` (recover
    x0/epsilon from a prediction) and ``sample`` (an iterative sampling generator).
    """

    def __init__(self, schedule, input_scaler, target, noise_cond, loss_weight, offset_noise=0):
        self.schedule = schedule
        self.input_scaler = input_scaler
        self.target = target
        self.noise_cond = noise_cond
        self.loss_weight = loss_weight
        self.offset_noise = offset_noise

    def setup_limits(self, stretch_max=True, stretch_min=True, shift=1):
        """Ask the input scaler to compute its stretched limits for this schedule and return them."""
        return self.input_scaler.setup_limits(self.schedule, self.input_scaler, stretch_max, stretch_min, shift)

    def diffuse(self, x0, epsilon=None, t=None, shift=1, loss_shift=1, offset=None):
        """Noise ``x0`` and return everything needed for a training step.

        Returns:
            Tuple ``(noised, epsilon, target, logSNR, noise_cond, loss_weight)``
            where ``noised = x0 * a + epsilon * b``.
        """
        if epsilon is None:
            epsilon = torch.randn_like(x0)
        if self.offset_noise > 0:
            if offset is None:
                # One random value per (sample, channel), broadcast over the remaining dims.
                offset_shape = [x0.size(0), x0.size(1)] + [1] * (len(x0.shape) - 2)
                offset = torch.randn(offset_shape).to(x0.device)
            epsilon = epsilon + offset * self.offset_noise

        # Without explicit t, the schedule receives the batch size instead.
        schedule_arg = x0.size(0) if t is None else t
        logSNR = self.schedule(schedule_arg, shift=shift).to(x0.device)
        a, b = self.input_scaler(logSNR)  # B
        if len(a.shape) == 1:
            per_sample = (-1,) + (1,) * (len(x0.shape) - 1)
            a, b = a.view(per_sample), b.view(per_sample)  # BxCxHxW
        target = self.target(x0, epsilon, logSNR, a, b)

        noised = x0 * a + epsilon * b
        cond = self.noise_cond(logSNR)
        weight = self.loss_weight(logSNR, shift=loss_shift)
        return noised, epsilon, target, logSNR, cond, weight

    def undiffuse(self, x, logSNR, pred):
        """Turn a model prediction into ``(x0, epsilon)`` estimates via ``self.target``."""
        a, b = self.input_scaler(logSNR)
        if len(a.shape) == 1:
            per_sample = (-1,) + (1,) * (len(x.shape) - 1)
            a, b = a.view(per_sample), b.view(per_sample)
        x0_est = self.target.x0(x, pred, logSNR, a, b)
        eps_est = self.target.epsilon(x, pred, logSNR, a, b)
        return x0_est, eps_est

    def sample(
        self,
        model,
        model_inputs,
        shape,
        unconditional_inputs=None,
        sampler=None,
        schedule=None,
        t_start=1.0,
        t_end=0.0,
        timesteps=20,
        x_init=None,
        cfg=3.0,
        cfg_t_stop=None,
        cfg_t_start=None,
        cfg_rho=0.7,
        sampler_params=None,
        shift=1,
        device="cpu",
    ):
        """Generator that runs ``timesteps`` sampling steps and yields ``(x0, x, pred)`` after each.

        When ``cfg`` is not None the conditional and unconditional inputs are
        concatenated along dim 0 and the model is run on a doubled batch while
        the current ``r`` lies inside ``[cfg_t_stop, cfg_t_start]``. A caller may
        ``send`` a dict to override ``cfg``, ``cfg_rho``, ``sampler``,
        ``model_inputs``, ``x`` or ``x_init`` for the following steps.
        """
        if sampler_params is None:
            sampler_params = {}
        if sampler is None:
            sampler = DDPMSampler(self)
        if schedule is None:
            schedule = self.schedule

        r_range = torch.linspace(t_start, t_end, timesteps + 1)
        batch = shape[0] if x_init is None else x_init.size(0)
        logSNR_range = schedule(r_range, shift=shift)[:, None].expand(-1, batch).to(device)

        x = sampler.init_x(shape).to(device) if x_init is None else x_init.clone()

        def stack_cond_uncond(cond, uncond):
            # Concatenate conditional and unconditional values along the batch dim.
            if isinstance(cond, torch.Tensor):
                return torch.cat([cond, uncond], dim=0)
            if isinstance(cond, list):
                return [
                    (
                        torch.cat([ci, ui], dim=0)
                        if isinstance(ci, torch.Tensor) and isinstance(ui, torch.Tensor)
                        else None
                    )
                    for ci, ui in zip(cond, uncond)
                ]
            if isinstance(cond, dict):
                return {
                    key: torch.cat([cond[key], uncond.get(key, torch.zeros_like(cond[key]))], dim=0) for key in cond
                }
            return None

        if cfg is not None:
            if unconditional_inputs is None:
                unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()}
            # Entries are paired by position in the two dicts, keyed by the conditional name.
            model_inputs = {
                k: stack_cond_uncond(v, v_u)
                for (k, v), (k_u, v_u) in zip(model_inputs.items(), unconditional_inputs.items())
            }

        for i in range(0, timesteps):
            noise_cond = self.noise_cond(logSNR_range[i])
            use_cfg = (
                cfg is not None
                and (cfg_t_stop is None or r_range[i].item() >= cfg_t_stop)
                and (cfg_t_start is None or r_range[i].item() <= cfg_t_start)
            )
            if use_cfg:
                cfg_val = cfg
                if isinstance(cfg_val, (list, tuple)):
                    assert len(cfg_val) == 2, "cfg must be a float or a list/tuple of length 2"
                    # Interpolate between the two guidance values according to r.
                    cfg_val = cfg_val[0] * r_range[i].item() + cfg_val[1] * (1 - r_range[i].item())
                doubled = model(torch.cat([x, x], dim=0), noise_cond.repeat(2), **model_inputs)
                pred, pred_unconditional = doubled.chunk(2)
                pred_cfg = torch.lerp(pred_unconditional, pred, cfg_val)
                if cfg_rho > 0:
                    # Blend the guided prediction with a copy rescaled to the conditional std.
                    std_pos, std_cfg = pred.std(), pred_cfg.std()
                    pred = cfg_rho * (pred_cfg * std_pos / (std_cfg + 1e-9)) + pred_cfg * (1 - cfg_rho)
                else:
                    pred = pred_cfg
            else:
                pred = model(x, noise_cond, **model_inputs)

            x0, epsilon = self.undiffuse(x, logSNR_range[i], pred)
            x = sampler(x, x0, epsilon, logSNR_range[i], logSNR_range[i + 1], **sampler_params)
            altered_vars = yield (x0, x, pred)

            # Update some running variables if the user wants
            if altered_vars is not None:
                cfg = altered_vars.get("cfg", cfg)
                cfg_rho = altered_vars.get("cfg_rho", cfg_rho)
                sampler = altered_vars.get("sampler", sampler)
                model_inputs = altered_vars.get("model_inputs", model_inputs)
                x = altered_vars.get("x", x)
                x_init = altered_vars.get("x_init", x_init)
1376
+
1377
+
1378
class BaseSchedule:
    """Base class mapping a time ``t`` (or a batch size) to a logSNR tensor.

    Subclasses implement ``setup`` (receives the constructor's extra arguments)
    and ``schedule(t, batch_size, ...)``.
    """

    def __init__(self, *args, force_limits=True, discrete_steps=None, shift=1, **kwargs):
        # setup() runs first, before the common attributes exist.
        self.setup(*args, **kwargs)
        self.limits = None
        self.discrete_steps = discrete_steps
        self.shift = shift
        if force_limits:
            self.reset_limits()

    def reset_limits(self, shift=1, disable=False):
        """Recompute ``self.limits`` as the logSNR at t=1 and t=0, or clear it.

        Returns the new limits, or None when disabled or when evaluating the
        schedule on a tensor fails (limits are then left unchanged).
        """
        if disable:
            self.limits = None
            return self.limits
        try:
            endpoints = torch.tensor([1.0, 0.0])
            self.limits = self(endpoints, shift=shift).tolist()  # min, max
            return self.limits
        except Exception:
            # Best effort: a schedule that does not accept a tensor t stays unbounded.
            return None

    def setup(self, *args, **kwargs):
        raise NotImplementedError("this method needs to be overridden")

    def schedule(self, *args, **kwargs):
        raise NotImplementedError("this method needs to be overridden")

    def __call__(self, t, *args, shift=1, **kwargs):
        """Return the (shifted, clamped) logSNR for ``t``.

        A tensor ``t`` is clamped to [0, 1], after optional snapping to
        ``discrete_steps`` levels; any other value is treated as a batch size
        and passed to ``schedule`` with ``t=None``.
        """
        if not isinstance(t, torch.Tensor):
            batch_size, t = t, None
        else:
            batch_size = None
            if self.discrete_steps is not None:
                last_step = self.discrete_steps - 1
                if t.dtype != torch.long:
                    t = (t * last_step).round().long()
                t = t / last_step
            t = t.clamp(0, 1)

        logSNR = self.schedule(t, batch_size, *args, **kwargs)
        total_shift = shift * self.shift
        if total_shift != 1:
            logSNR += 2 * np.log(1 / total_shift)
        if self.limits is not None:
            logSNR = logSNR.clamp(*self.limits)
        return logSNR
1418
+
1419
+
1420
class CosineSchedule(BaseSchedule):
    """Schedule whose variance follows a squared cosine of ``t``."""

    def setup(self, s=0.008, clamp_range=[0.0001, 0.9999], norm_instead=False):
        """Store the offset ``s``, the variance range and the variance at t=0 (``min_var``)."""
        self.s = torch.tensor([s])
        self.clamp_range = clamp_range
        self.norm_instead = norm_instead
        self.min_var = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2

    def schedule(self, t, batch_size):
        """Return ``log(var / (1 - var))`` for ``t``; random ``t`` is drawn when ``t`` is None."""
        if t is None:
            # Uniform samples shifted by 0.001 and clamped to [0.001, 1.0].
            t = (1 - torch.rand(batch_size)).add(0.001).clamp(0.001, 1.0)
        s = self.s.to(t.device)
        min_var = self.min_var.to(t.device)

        angle = (s + t) / (1 + s) * torch.pi * 0.5
        var = torch.cos(angle).clamp(0, 1) ** 2 / min_var
        if not self.norm_instead:
            var = var.clamp(*self.clamp_range)
        else:
            # Rescale into clamp_range instead of clipping.
            var = var * (self.clamp_range[1] - self.clamp_range[0]) + self.clamp_range[0]
        return (var / (1 - var)).log()
1438
+
1439
+
1440
class BaseScaler:
    """Base class turning a logSNR into the pair of scaling factors ``(a, b)``.

    Subclasses implement ``scalers``. After ``setup_limits`` has been called,
    the factors are linearly stretched using the stored limits.
    """

    def __init__(self):
        self.stretched_limits = None

    def setup_limits(self, schedule, input_scaler, stretch_max=True, stretch_min=True, shift=1):
        """Compute and store ``[min_a, max_a, min_b, max_b]`` from the schedule end points.

        ``schedule`` is evaluated at t=1 (lowest logSNR) and t=0 (highest);
        a disabled stretch side falls back to the identity limits 0 and 1.
        """
        min_logSNR = schedule(torch.ones(1), shift=shift)
        max_logSNR = schedule(torch.zeros(1), shift=shift)

        if stretch_max:
            min_a, max_b = [v.item() for v in input_scaler(min_logSNR)]
        else:
            min_a, max_b = [0, 1]
        if stretch_min:
            max_a, min_b = [v.item() for v in input_scaler(max_logSNR)]
        else:
            max_a, min_b = [1, 0]

        self.stretched_limits = [min_a, max_a, min_b, max_b]
        return self.stretched_limits

    def stretch_limits(self, a, b):
        """Map ``a`` from [min_a, max_a] and ``b`` from [min_b, max_b] onto [0, 1]."""
        min_a, max_a, min_b, max_b = self.stretched_limits
        stretched_a = (a - min_a) / (max_a - min_a)
        stretched_b = (b - min_b) / (max_b - min_b)
        return stretched_a, stretched_b

    def scalers(self, logSNR):
        raise NotImplementedError("this method needs to be overridden")

    def __call__(self, logSNR):
        """Return ``(a, b)`` for ``logSNR``, stretched when limits are set."""
        a, b = self.scalers(logSNR)
        if self.stretched_limits is None:
            return a, b
        return self.stretch_limits(a, b)
1465
+
1466
+
1467
class VPScaler(BaseScaler):
    """Scaler with ``a^2 + b^2 = 1``: ``a = sqrt(sigmoid(logSNR))``, ``b = sqrt(1 - sigmoid(logSNR))``."""

    def scalers(self, logSNR):
        """Return the unstretched ``(a, b)`` pair for ``logSNR``."""
        signal_var = logSNR.sigmoid()
        noise_var = 1 - signal_var
        return signal_var.sqrt(), noise_var.sqrt()
1473
+
1474
+
1475
class EpsilonTarget:
    """Target definition for models that predict the noise ``epsilon``."""

    def __call__(self, x0, epsilon, logSNR, a, b):
        """The training target is the noise itself."""
        return epsilon

    def x0(self, noised, pred, logSNR, a, b):
        """Recover x0 from ``noised = x0 * a + epsilon * b`` using the predicted noise."""
        predicted_noise_part = pred * b
        return (noised - predicted_noise_part) / a

    def epsilon(self, noised, pred, logSNR, a, b):
        """The prediction already is the noise estimate."""
        return pred
1484
+
1485
+
1486
class BaseNoiseCond:
    """Base class converting a logSNR into the noise-level value given to the model.

    Subclasses implement ``cond`` and may implement ``setup``, which receives
    the constructor's remaining arguments.
    """

    def __init__(self, *args, shift=1, clamp_range=None, **kwargs):
        if clamp_range is None:
            clamp_range = [-1e9, 1e9]
        self.shift = shift
        self.clamp_range = clamp_range
        # setup() runs last, so it may overwrite the attributes set above.
        self.setup(*args, **kwargs)

    def setup(self, *args, **kwargs):
        pass  # this method is optional, override it if required

    def cond(self, logSNR):
        raise NotImplementedError("this method needs to be overridden")

    def __call__(self, logSNR):
        """Return ``cond(logSNR + 2 * log(shift))`` clamped to ``clamp_range``."""
        if self.shift != 1:
            # Work on a copy so the caller's tensor is not modified.
            logSNR = logSNR.clone() + 2 * np.log(self.shift)
        value = self.cond(logSNR)
        return value.clamp(*self.clamp_range)
1503
+
1504
+
1505
class CosineTNoiseCond(BaseNoiseCond):
    """Noise condition that converts a logSNR into a cosine-schedule time ``t``."""

    def setup(self, s=0.008, clamp_range=[0, 1]):  # [0.0001, 0.9999]
        """Store the offset ``s``, the clamp range and the variance at t=0.

        NOTE(review): the base constructor consumes a ``clamp_range`` keyword
        itself, so this method appears to always see its own default and
        overwrite the constructor's value; confirm whether that is intended.
        """
        self.s = torch.tensor([s])
        self.clamp_range = clamp_range
        self.min_var = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2

    def cond(self, logSNR):
        """Return ``t`` computed from ``sigmoid(logSNR)`` via an arccos."""
        var = logSNR.sigmoid().clamp(*self.clamp_range)
        s = self.s.to(var.device)
        min_var = self.min_var.to(var.device)
        # Fraction of a quarter turn corresponding to the (rescaled) variance.
        quarter_turns = ((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)
        return quarter_turns * (1 + s) - s
1517
+
1518
+
1519
+ # --- Loss Weighting
1520
class BaseLossWeight:
    """Base class for per-sample loss weights computed from a logSNR."""

    def weight(self, logSNR):
        raise NotImplementedError("this method needs to be overridden")

    def __call__(self, logSNR, *args, shift=1, clamp_range=None, **kwargs):
        """Return ``weight(logSNR + 2 * log(shift))`` clamped to ``clamp_range``.

        ``clamp_range`` defaults to ``[-1e9, 1e9]``; extra arguments are passed
        on to ``weight``.
        """
        if clamp_range is None:
            clamp_range = [-1e9, 1e9]
        if shift != 1:
            # Work on a copy so the caller's tensor is not modified.
            logSNR = logSNR.clone() + 2 * np.log(shift)
        raw = self.weight(logSNR, *args, **kwargs)
        return raw.clamp(*clamp_range)
1529
+
1530
+
1531
+ # class ComposedLossWeight(BaseLossWeight):
1532
+ # def __init__(self, div, mul):
1533
+ # self.mul = [mul] if isinstance(mul, BaseLossWeight) else mul
1534
+ # self.div = [div] if isinstance(div, BaseLossWeight) else div
1535
+
1536
+ # def weight(self, logSNR):
1537
+ # prod, div = 1, 1
1538
+ # for m in self.mul:
1539
+ # prod *= m.weight(logSNR)
1540
+ # for d in self.div:
1541
+ # div *= d.weight(logSNR)
1542
+ # return prod/div
1543
+
1544
+ # class ConstantLossWeight(BaseLossWeight):
1545
+ # def __init__(self, v=1):
1546
+ # self.v = v
1547
+
1548
+ # def weight(self, logSNR):
1549
+ # return torch.ones_like(logSNR) * self.v
1550
+
1551
+ # class SNRLossWeight(BaseLossWeight):
1552
+ # def weight(self, logSNR):
1553
+ # return logSNR.exp()
1554
+
1555
+
1556
class P2LossWeight(BaseLossWeight):
    """Loss weight ``(k + exp(logSNR * s)) ** -gamma``."""

    def __init__(self, k=1.0, gamma=1.0, s=1.0):
        self.k = k
        self.gamma = gamma
        self.s = s

    def weight(self, logSNR):
        """Return the weight for each entry of ``logSNR``."""
        scaled_snr = (logSNR * self.s).exp()
        return (self.k + scaled_snr) ** -self.gamma
1562
+
1563
+
1564
+ # class SNRPlusOneLossWeight(BaseLossWeight):
1565
+ # def weight(self, logSNR):
1566
+ # return logSNR.exp() + 1
1567
+
1568
+ # class MinSNRLossWeight(BaseLossWeight):
1569
+ # def __init__(self, max_snr=5):
1570
+ # self.max_snr = max_snr
1571
+
1572
+ # def weight(self, logSNR):
1573
+ # return logSNR.exp().clamp(max=self.max_snr)
1574
+
1575
+ # class MinSNRPlusOneLossWeight(BaseLossWeight):
1576
+ # def __init__(self, max_snr=5):
1577
+ # self.max_snr = max_snr
1578
+
1579
+ # def weight(self, logSNR):
1580
+ # return (logSNR.exp() + 1).clamp(max=self.max_snr)
1581
+
1582
+ # class TruncatedSNRLossWeight(BaseLossWeight):
1583
+ # def __init__(self, min_snr=1):
1584
+ # self.min_snr = min_snr
1585
+
1586
+ # def weight(self, logSNR):
1587
+ # return logSNR.exp().clamp(min=self.min_snr)
1588
+
1589
+ # class SechLossWeight(BaseLossWeight):
1590
+ # def __init__(self, div=2):
1591
+ # self.div = div
1592
+
1593
+ # def weight(self, logSNR):
1594
+ # return 1/(logSNR/self.div).cosh()
1595
+
1596
+ # class DebiasedLossWeight(BaseLossWeight):
1597
+ # def weight(self, logSNR):
1598
+ # return 1/logSNR.exp().sqrt()
1599
+
1600
+ # class SigmoidLossWeight(BaseLossWeight):
1601
+ # def __init__(self, s=1):
1602
+ # self.s = s
1603
+
1604
+ # def weight(self, logSNR):
1605
+ # return (logSNR * self.s).sigmoid()
1606
+
1607
+
1608
class AdaptiveLossWeight(BaseLossWeight):
    """Loss weight equal to the inverse of a running loss average per logSNR bucket."""

    def __init__(self, logsnr_range=[-10, 10], buckets=300, weight_range=[1e-7, 1e7]):
        # buckets - 1 boundaries split the logSNR axis into `buckets` bins.
        self.bucket_ranges = torch.linspace(logsnr_range[0], logsnr_range[1], buckets - 1)
        self.bucket_losses = torch.ones(buckets)
        self.weight_range = weight_range

    def weight(self, logSNR):
        """Return ``1 / bucket_loss`` for each logSNR, clamped to ``weight_range``."""
        boundaries = self.bucket_ranges.to(logSNR.device)
        indices = torch.searchsorted(boundaries, logSNR)
        losses = self.bucket_losses.to(logSNR.device)[indices]
        return (1 / losses).clamp(*self.weight_range)

    def update_buckets(self, logSNR, loss, beta=0.99):
        """Move the running loss of each hit bucket towards ``loss`` with factor ``1 - beta``."""
        boundaries = self.bucket_ranges.to(logSNR.device)
        indices = torch.searchsorted(boundaries, logSNR).cpu()
        previous = self.bucket_losses[indices]
        self.bucket_losses[indices] = previous * beta + loss.detach().cpu() * (1 - beta)
1621
+
1622
+
1623
+ # endregion gdf