aifeifei798 committed
Commit 851bd03 (1 parent: 2c166e9)

Upload 154 files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. extras/expansion.py +129 -0
  2. extras/fooocus_expansion/README.md +12 -0
  3. extras/fooocus_expansion/config.json +40 -0
  4. extras/fooocus_expansion/huggingface-metadata.txt +5 -0
  5. extras/fooocus_expansion/merges.txt +0 -0
  6. extras/fooocus_expansion/positive.txt +642 -0
  7. extras/fooocus_expansion/pytorch_model.bin +3 -0
  8. extras/fooocus_expansion/special_tokens_map.json +5 -0
  9. extras/fooocus_expansion/tokenizer.json +0 -0
  10. extras/fooocus_expansion/tokenizer_config.json +10 -0
  11. extras/fooocus_expansion/vocab.json +0 -0
  12. ldm_patched/contrib/external.py +1954 -0
  13. ldm_patched/contrib/external_align_your_steps.py +55 -0
  14. ldm_patched/contrib/external_canny.py +301 -0
  15. ldm_patched/contrib/external_clip_sdxl.py +58 -0
  16. ldm_patched/contrib/external_compositing.py +204 -0
  17. ldm_patched/contrib/external_custom_sampler.py +316 -0
  18. ldm_patched/contrib/external_freelunch.py +115 -0
  19. ldm_patched/contrib/external_hypernetwork.py +121 -0
  20. ldm_patched/contrib/external_hypertile.py +85 -0
  21. ldm_patched/contrib/external_images.py +177 -0
  22. ldm_patched/contrib/external_latent.py +157 -0
  23. ldm_patched/contrib/external_mask.py +365 -0
  24. ldm_patched/contrib/external_model_advanced.py +188 -0
  25. ldm_patched/contrib/external_model_downscale.py +55 -0
  26. ldm_patched/contrib/external_model_merging.py +286 -0
  27. ldm_patched/contrib/external_perpneg.py +57 -0
  28. ldm_patched/contrib/external_photomaker.py +189 -0
  29. ldm_patched/contrib/external_post_processing.py +278 -0
  30. ldm_patched/contrib/external_rebatch.py +140 -0
  31. ldm_patched/contrib/external_sag.py +172 -0
  32. ldm_patched/contrib/external_sdupscale.py +49 -0
  33. ldm_patched/contrib/external_stable3d.py +104 -0
  34. ldm_patched/contrib/external_tomesd.py +179 -0
  35. ldm_patched/contrib/external_upscale_model.py +68 -0
  36. ldm_patched/contrib/external_video_model.py +108 -0
  37. ldm_patched/controlnet/cldm.py +312 -0
  38. ldm_patched/k_diffusion/sampling.py +908 -0
  39. ldm_patched/k_diffusion/utils.py +313 -0
  40. ldm_patched/ldm/models/autoencoder.py +228 -0
  41. ldm_patched/ldm/modules/attention.py +781 -0
  42. ldm_patched/ldm/modules/diffusionmodules/__init__.py +0 -0
  43. ldm_patched/ldm/modules/diffusionmodules/model.py +650 -0
  44. ldm_patched/ldm/modules/diffusionmodules/openaimodel.py +886 -0
  45. ldm_patched/ldm/modules/diffusionmodules/upscaling.py +85 -0
  46. ldm_patched/ldm/modules/diffusionmodules/util.py +304 -0
  47. ldm_patched/ldm/modules/distributions/__init__.py +0 -0
  48. ldm_patched/ldm/modules/distributions/distributions.py +92 -0
  49. ldm_patched/ldm/modules/ema.py +80 -0
  50. ldm_patched/ldm/modules/encoders/__init__.py +0 -0
extras/expansion.py ADDED
@@ -0,0 +1,129 @@
+ # Fooocus GPT2 Expansion
+ # Algorithm created by Lvmin Zhang at 2023, Stanford
+ # If used inside Fooocus, any use is permitted.
+ # If used outside Fooocus, only non-commercial use is permitted (CC-By NC 4.0).
+ # This applies to the word list, vocab, model, and algorithm.
+
+
+ import os
+ import torch
+ import math
+ import ldm_patched.modules.model_management as model_management
+
+ from transformers.generation.logits_process import LogitsProcessorList
+ from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
+ # from modules.config import path_fooocus_expansion
+ from ldm_patched.modules.model_patcher import ModelPatcher
+
+ path_fooocus_expansion = "extras/fooocus_expansion"
+ # limitation of np.random.seed(), called from transformers.set_seed()
+ SEED_LIMIT_NUMPY = 2**32
+ neg_inf = - 8192.0
+
+
+ def safe_str(x):
+     x = str(x)
+     for _ in range(16):
+         x = x.replace('  ', ' ')
+     return x.strip(",. \r\n")
+
+
+ def remove_pattern(x, pattern):
+     for p in pattern:
+         x = x.replace(p, '')
+     return x
+
+
+ class FooocusExpansion:
+     def __init__(self):
+         self.tokenizer = AutoTokenizer.from_pretrained(path_fooocus_expansion)
+
+         positive_words = open(os.path.join(path_fooocus_expansion, 'positive.txt'),
+                               encoding='utf-8').read().splitlines()
+         positive_words = ['Ġ' + x.lower() for x in positive_words if x != '']
+
+         self.logits_bias = torch.zeros((1, len(self.tokenizer.vocab)), dtype=torch.float32) + neg_inf
+
+         debug_list = []
+         for k, v in self.tokenizer.vocab.items():
+             if k in positive_words:
+                 self.logits_bias[0, v] = 0
+                 debug_list.append(k[1:])
+
+         print(f'Fooocus V2 Expansion: Vocab with {len(debug_list)} words.')
+
+         # debug_list = '\n'.join(sorted(debug_list))
+         # print(debug_list)
+
+         # t11 = self.tokenizer(',', return_tensors="np")
+         # t198 = self.tokenizer('\n', return_tensors="np")
+         # eos = self.tokenizer.eos_token_id
+
+         self.model = AutoModelForCausalLM.from_pretrained(path_fooocus_expansion)
+         self.model.eval()
+
+         load_device = model_management.text_encoder_device()
+         offload_device = model_management.text_encoder_offload_device()
+
+         # MPS hack
+         if model_management.is_device_mps(load_device):
+             load_device = torch.device('cpu')
+             offload_device = torch.device('cpu')
+
+         use_fp16 = model_management.should_use_fp16(device=load_device)
+
+         if use_fp16:
+             self.model.half()
+
+         self.patcher = ModelPatcher(self.model, load_device=load_device, offload_device=offload_device)
+         print(f'Fooocus Expansion engine loaded for {load_device}, use_fp16 = {use_fp16}.')
+
+     @torch.no_grad()
+     @torch.inference_mode()
+     def logits_processor(self, input_ids, scores):
+         assert scores.ndim == 2 and scores.shape[0] == 1
+         self.logits_bias = self.logits_bias.to(scores)
+
+         bias = self.logits_bias.clone()
+         bias[0, input_ids[0].to(bias.device).long()] = neg_inf
+         bias[0, 11] = 0
+
+         return scores + bias
+
+     @torch.no_grad()
+     @torch.inference_mode()
+     def __call__(self, prompt, seed):
+         if prompt == '':
+             return ''
+
+         if self.patcher.current_device != self.patcher.load_device:
+             print('Fooocus Expansion loaded by itself.')
+             model_management.load_model_gpu(self.patcher)
+
+         seed = int(seed) % SEED_LIMIT_NUMPY
+         set_seed(seed)
+         prompt = safe_str(prompt) + ','
+
+         tokenized_kwargs = self.tokenizer(prompt, return_tensors="pt")
+         tokenized_kwargs.data['input_ids'] = tokenized_kwargs.data['input_ids'].to(self.patcher.load_device)
+         tokenized_kwargs.data['attention_mask'] = tokenized_kwargs.data['attention_mask'].to(self.patcher.load_device)
+
+         current_token_length = int(tokenized_kwargs.data['input_ids'].shape[1])
+         max_token_length = 75 * int(math.ceil(float(current_token_length) / 75.0))
+         max_new_tokens = max_token_length - current_token_length
+
+         if max_new_tokens == 0:
+             return prompt[:-1]
+
+         # https://huggingface.co/blog/introducing-csearch
+         # https://huggingface.co/docs/transformers/generation_strategies
+         features = self.model.generate(**tokenized_kwargs,
+                                        top_k=100,
+                                        max_new_tokens=max_new_tokens,
+                                        do_sample=True,
+                                        logits_processor=LogitsProcessorList([self.logits_processor]))
+
+         response = self.tokenizer.batch_decode(features, skip_special_tokens=True)
+         result = safe_str(response[0])
+
+         return result
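For orientation, here is a minimal usage sketch of the FooocusExpansion class added above. It is an illustration only: it assumes the repository root is on sys.path so that extras.expansion and the ldm_patched runtime modules import cleanly, and the prompt and seed are made-up values.

from extras.expansion import FooocusExpansion

expansion = FooocusExpansion()  # loads the tokenizer and GPT2 weights from extras/fooocus_expansion
# __call__ seeds transformers, appends a trailing comma, then samples extra "positive" words
# until the prompt reaches the next multiple of 75 tokens (e.g. a 12-token prompt can
# receive up to 63 new tokens).
expanded = expansion('a cat on a sofa', seed=1234)
print(expanded)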
extras/fooocus_expansion/README.md ADDED
@@ -0,0 +1,12 @@
+ ---
+ license: agpl-3.0
+ ---
+
+ GPT2 Prompt Expansion model from [lllyasviel/Fooocus](https://github.com/lllyasviel/Fooocus)
+
+ Third-party [license terms](https://github.com/lllyasviel/Fooocus/blob/main/LICENSE)
+
+ ## Disclaimer
+ All trademarks, logos, and brand names are the property of their respective owners.
+ All company, product and service names used in this website and licensed applications are for identification purposes only.
+ Use of these names, trademarks, and brands does not imply endorsement.
extras/fooocus_expansion/config.json ADDED
@@ -0,0 +1,40 @@
+ {
+   "_name_or_path": "gpt2",
+   "activation_function": "gelu_new",
+   "architectures": [
+     "GPT2LMHeadModel"
+   ],
+   "attn_pdrop": 0.1,
+   "bos_token_id": 50256,
+   "embd_pdrop": 0.1,
+   "eos_token_id": 50256,
+   "pad_token_id": 50256,
+   "initializer_range": 0.02,
+   "layer_norm_epsilon": 1e-05,
+   "model_type": "gpt2",
+   "n_ctx": 1024,
+   "n_embd": 768,
+   "n_head": 12,
+   "n_inner": null,
+   "n_layer": 12,
+   "n_positions": 1024,
+   "reorder_and_upcast_attn": false,
+   "resid_pdrop": 0.1,
+   "scale_attn_by_inverse_layer_idx": false,
+   "scale_attn_weights": true,
+   "summary_activation": null,
+   "summary_first_dropout": 0.1,
+   "summary_proj_to_labels": true,
+   "summary_type": "cls_index",
+   "summary_use_proj": true,
+   "task_specific_params": {
+     "text-generation": {
+       "do_sample": true,
+       "max_length": 50
+     }
+   },
+   "torch_dtype": "float32",
+   "transformers_version": "4.23.0.dev0",
+   "use_cache": true,
+   "vocab_size": 50257
+ }
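The config above is the stock GPT-2 small configuration (12 layers, 12 heads, 768-dim embeddings, 50257-token vocabulary) with pad_token_id set to the end-of-text token. A quick, hedged sanity check after downloading, using the path as laid out in this commit:

from transformers import GPT2Config

cfg = GPT2Config.from_pretrained("extras/fooocus_expansion")
print(cfg.n_layer, cfg.n_head, cfg.n_embd, cfg.vocab_size)  # expected: 12 12 768 50257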
extras/fooocus_expansion/huggingface-metadata.txt ADDED
@@ -0,0 +1,5 @@
+ url: https://huggingface.co/LykosAI/GPT-Prompt-Expansion-Fooocus-v2
+ branch: main
+ download date: 2024-04-20 17:49:07
+ sha256sum:
+ dd54cc90d95d2c72b97830e4b38f44a6521847284d5b9dbcfd16ba82779cdeb3 pytorch_model.bin
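The metadata records a SHA-256 checksum for the weight file. A small verification sketch follows; the path mirrors the layout of this commit and should be adjusted if the files live elsewhere.

import hashlib

expected = "dd54cc90d95d2c72b97830e4b38f44a6521847284d5b9dbcfd16ba82779cdeb3"
h = hashlib.sha256()
with open("extras/fooocus_expansion/pytorch_model.bin", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # read in 1 MiB chunks
        h.update(chunk)
assert h.hexdigest() == expected, "pytorch_model.bin does not match the recorded checksum"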
extras/fooocus_expansion/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
extras/fooocus_expansion/positive.txt ADDED
@@ -0,0 +1,642 @@
1
+ abundant
2
+ accelerated
3
+ accepted
4
+ accepting
5
+ acclaimed
6
+ accomplished
7
+ acknowledged
8
+ activated
9
+ adapted
10
+ adjusted
11
+ admirable
12
+ adorable
13
+ adorned
14
+ advanced
15
+ adventurous
16
+ advocated
17
+ aesthetic
18
+ affirmed
19
+ affluent
20
+ agile
21
+ aimed
22
+ aligned
23
+ alive
24
+ altered
25
+ amazing
26
+ ambient
27
+ amplified
28
+ analytical
29
+ animated
30
+ appealing
31
+ applauded
32
+ appreciated
33
+ ardent
34
+ aromatic
35
+ arranged
36
+ arresting
37
+ articulate
38
+ artistic
39
+ associated
40
+ assured
41
+ astonishing
42
+ astounding
43
+ atmosphere
44
+ attempted
45
+ attentive
46
+ attractive
47
+ authentic
48
+ authoritative
49
+ awarded
50
+ awesome
51
+ backed
52
+ background
53
+ baked
54
+ balance
55
+ balanced
56
+ balancing
57
+ beaten
58
+ beautiful
59
+ beloved
60
+ beneficial
61
+ benevolent
62
+ best
63
+ bestowed
64
+ blazing
65
+ blended
66
+ blessed
67
+ boosted
68
+ borne
69
+ brave
70
+ breathtaking
71
+ brewed
72
+ bright
73
+ brilliant
74
+ brought
75
+ built
76
+ burning
77
+ calm
78
+ calmed
79
+ candid
80
+ caring
81
+ carried
82
+ catchy
83
+ celebrated
84
+ celestial
85
+ certain
86
+ championed
87
+ changed
88
+ charismatic
89
+ charming
90
+ chased
91
+ cheered
92
+ cheerful
93
+ cherished
94
+ chic
95
+ chosen
96
+ cinematic
97
+ clad
98
+ classic
99
+ classy
100
+ clear
101
+ coached
102
+ coherent
103
+ collected
104
+ color
105
+ colorful
106
+ colors
107
+ colossal
108
+ combined
109
+ comforting
110
+ commanding
111
+ committed
112
+ compassionate
113
+ compatible
114
+ complete
115
+ complex
116
+ complimentary
117
+ composed
118
+ composition
119
+ comprehensive
120
+ conceived
121
+ conferred
122
+ confident
123
+ connected
124
+ considerable
125
+ considered
126
+ consistent
127
+ conspicuous
128
+ constructed
129
+ constructive
130
+ contemplated
131
+ contemporary
132
+ content
133
+ contrasted
134
+ conveyed
135
+ cooked
136
+ cool
137
+ coordinated
138
+ coupled
139
+ courageous
140
+ coveted
141
+ cozy
142
+ created
143
+ creative
144
+ credited
145
+ crisp
146
+ critical
147
+ cultivated
148
+ cured
149
+ curious
150
+ current
151
+ customized
152
+ cute
153
+ daring
154
+ darling
155
+ dazzling
156
+ decorated
157
+ decorative
158
+ dedicated
159
+ deep
160
+ defended
161
+ definitive
162
+ delicate
163
+ delightful
164
+ delivered
165
+ depicted
166
+ designed
167
+ desirable
168
+ desired
169
+ destined
170
+ detail
171
+ detailed
172
+ determined
173
+ developed
174
+ devoted
175
+ devout
176
+ diligent
177
+ direct
178
+ directed
179
+ discovered
180
+ dispatched
181
+ displayed
182
+ distilled
183
+ distinct
184
+ distinctive
185
+ distinguished
186
+ diverse
187
+ divine
188
+ dramatic
189
+ draped
190
+ dreamed
191
+ driven
192
+ dynamic
193
+ earnest
194
+ eased
195
+ ecstatic
196
+ educated
197
+ effective
198
+ elaborate
199
+ elegant
200
+ elevated
201
+ elite
202
+ eminent
203
+ emotional
204
+ empowered
205
+ empowering
206
+ enchanted
207
+ encouraged
208
+ endorsed
209
+ endowed
210
+ enduring
211
+ energetic
212
+ engaging
213
+ enhanced
214
+ enigmatic
215
+ enlightened
216
+ enormous
217
+ enticing
218
+ envisioned
219
+ epic
220
+ esteemed
221
+ eternal
222
+ everlasting
223
+ evolved
224
+ exalted
225
+ examining
226
+ excellent
227
+ exceptional
228
+ exciting
229
+ exclusive
230
+ exemplary
231
+ exotic
232
+ expansive
233
+ exposed
234
+ expressive
235
+ exquisite
236
+ extended
237
+ extraordinary
238
+ extremely
239
+ fabulous
240
+ facilitated
241
+ fair
242
+ faithful
243
+ famous
244
+ fancy
245
+ fantastic
246
+ fascinating
247
+ fashionable
248
+ fashioned
249
+ favorable
250
+ favored
251
+ fearless
252
+ fermented
253
+ fertile
254
+ festive
255
+ fiery
256
+ fine
257
+ finest
258
+ firm
259
+ fixed
260
+ flaming
261
+ flashing
262
+ flashy
263
+ flavored
264
+ flawless
265
+ flourishing
266
+ flowing
267
+ focus
268
+ focused
269
+ formal
270
+ formed
271
+ fortunate
272
+ fostering
273
+ frank
274
+ fresh
275
+ fried
276
+ friendly
277
+ fruitful
278
+ fulfilled
279
+ full
280
+ futuristic
281
+ generous
282
+ gentle
283
+ genuine
284
+ gifted
285
+ gigantic
286
+ glamorous
287
+ glorious
288
+ glossy
289
+ glowing
290
+ gorgeous
291
+ graceful
292
+ gracious
293
+ grand
294
+ granted
295
+ grateful
296
+ great
297
+ grilled
298
+ grounded
299
+ grown
300
+ guarded
301
+ guided
302
+ hailed
303
+ handsome
304
+ healing
305
+ healthy
306
+ heartfelt
307
+ heavenly
308
+ heroic
309
+ highly
310
+ historic
311
+ holistic
312
+ holy
313
+ honest
314
+ honored
315
+ hoped
316
+ hopeful
317
+ iconic
318
+ ideal
319
+ illuminated
320
+ illuminating
321
+ illumination
322
+ illustrious
323
+ imaginative
324
+ imagined
325
+ immense
326
+ immortal
327
+ imposing
328
+ impressive
329
+ improved
330
+ incredible
331
+ infinite
332
+ informed
333
+ ingenious
334
+ innocent
335
+ innovative
336
+ insightful
337
+ inspirational
338
+ inspired
339
+ inspiring
340
+ instructed
341
+ integrated
342
+ intense
343
+ intricate
344
+ intriguing
345
+ invaluable
346
+ invented
347
+ investigative
348
+ invincible
349
+ inviting
350
+ irresistible
351
+ joined
352
+ joyful
353
+ keen
354
+ kindly
355
+ kinetic
356
+ knockout
357
+ laced
358
+ lasting
359
+ lauded
360
+ lavish
361
+ legendary
362
+ lifted
363
+ light
364
+ limited
365
+ linked
366
+ lively
367
+ located
368
+ logical
369
+ loved
370
+ lovely
371
+ loving
372
+ loyal
373
+ lucid
374
+ lucky
375
+ lush
376
+ luxurious
377
+ luxury
378
+ magic
379
+ magical
380
+ magnificent
381
+ majestic
382
+ marked
383
+ marvelous
384
+ massive
385
+ matched
386
+ matured
387
+ meaningful
388
+ memorable
389
+ merged
390
+ merry
391
+ meticulous
392
+ mindful
393
+ miraculous
394
+ modern
395
+ modified
396
+ monstrous
397
+ monumental
398
+ motivated
399
+ motivational
400
+ moved
401
+ moving
402
+ mystical
403
+ mythical
404
+ naive
405
+ neat
406
+ new
407
+ nice
408
+ nifty
409
+ noble
410
+ notable
411
+ noteworthy
412
+ novel
413
+ nuanced
414
+ offered
415
+ open
416
+ optimal
417
+ optimistic
418
+ orderly
419
+ organized
420
+ original
421
+ originated
422
+ outstanding
423
+ overwhelming
424
+ paired
425
+ palpable
426
+ passionate
427
+ peaceful
428
+ perfect
429
+ perfected
430
+ perpetual
431
+ persistent
432
+ phenomenal
433
+ pious
434
+ pivotal
435
+ placed
436
+ planned
437
+ pleasant
438
+ pleased
439
+ pleasing
440
+ plentiful
441
+ plotted
442
+ plush
443
+ poetic
444
+ poignant
445
+ polished
446
+ positive
447
+ praised
448
+ precious
449
+ precise
450
+ premier
451
+ premium
452
+ presented
453
+ preserved
454
+ prestigious
455
+ pretty
456
+ priceless
457
+ prime
458
+ pristine
459
+ probing
460
+ productive
461
+ professional
462
+ profound
463
+ progressed
464
+ progressive
465
+ prominent
466
+ promoted
467
+ pronounced
468
+ propelled
469
+ proportional
470
+ prosperous
471
+ protected
472
+ provided
473
+ provocative
474
+ pure
475
+ pursued
476
+ pushed
477
+ quaint
478
+ quality
479
+ questioning
480
+ quiet
481
+ radiant
482
+ rare
483
+ rational
484
+ real
485
+ reborn
486
+ reclaimed
487
+ recognized
488
+ recovered
489
+ refined
490
+ reflected
491
+ refreshed
492
+ refreshing
493
+ related
494
+ relaxed
495
+ relentless
496
+ reliable
497
+ relieved
498
+ remarkable
499
+ renewed
500
+ renowned
501
+ representative
502
+ rescued
503
+ resilient
504
+ respected
505
+ respectful
506
+ restored
507
+ retrieved
508
+ revealed
509
+ revealing
510
+ revered
511
+ revived
512
+ rewarded
513
+ rich
514
+ roasted
515
+ robust
516
+ romantic
517
+ royal
518
+ sacred
519
+ salient
520
+ satisfied
521
+ satisfying
522
+ saturated
523
+ saved
524
+ scenic
525
+ scientific
526
+ select
527
+ sensational
528
+ serious
529
+ set
530
+ shaped
531
+ sharp
532
+ shielded
533
+ shining
534
+ shiny
535
+ shown
536
+ significant
537
+ silent
538
+ sincere
539
+ singular
540
+ situated
541
+ sleek
542
+ slick
543
+ smart
544
+ snug
545
+ solemn
546
+ solid
547
+ soothing
548
+ sophisticated
549
+ sought
550
+ sparkling
551
+ special
552
+ spectacular
553
+ sped
554
+ spirited
555
+ spiritual
556
+ splendid
557
+ spread
558
+ stable
559
+ steady
560
+ still
561
+ stimulated
562
+ stimulating
563
+ stirred
564
+ straightforward
565
+ striking
566
+ strong
567
+ structured
568
+ stunning
569
+ sturdy
570
+ stylish
571
+ sublime
572
+ successful
573
+ sunny
574
+ superb
575
+ superior
576
+ supplied
577
+ supported
578
+ supportive
579
+ supreme
580
+ sure
581
+ surreal
582
+ sweet
583
+ symbolic
584
+ symmetry
585
+ synchronized
586
+ systematic
587
+ tailored
588
+ taking
589
+ targeted
590
+ taught
591
+ tempting
592
+ tender
593
+ terrific
594
+ thankful
595
+ theatrical
596
+ thought
597
+ thoughtful
598
+ thrilled
599
+ thrilling
600
+ thriving
601
+ tidy
602
+ timeless
603
+ touching
604
+ tough
605
+ trained
606
+ tranquil
607
+ transformed
608
+ translucent
609
+ transparent
610
+ transported
611
+ tremendous
612
+ trendy
613
+ tried
614
+ trim
615
+ true
616
+ trustworthy
617
+ unbelievable
618
+ unconditional
619
+ uncovered
620
+ unified
621
+ unique
622
+ united
623
+ universal
624
+ unmatched
625
+ unparalleled
626
+ upheld
627
+ valiant
628
+ valued
629
+ varied
630
+ very
631
+ vibrant
632
+ virtuous
633
+ vivid
634
+ warm
635
+ wealthy
636
+ whole
637
+ winning
638
+ wished
639
+ witty
640
+ wonderful
641
+ worshipped
642
+ worthy
extras/fooocus_expansion/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dd54cc90d95d2c72b97830e4b38f44a6521847284d5b9dbcfd16ba82779cdeb3
+ size 351283802
extras/fooocus_expansion/special_tokens_map.json ADDED
@@ -0,0 +1,5 @@
+ {
+   "bos_token": "<|endoftext|>",
+   "eos_token": "<|endoftext|>",
+   "unk_token": "<|endoftext|>"
+ }
extras/fooocus_expansion/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
extras/fooocus_expansion/tokenizer_config.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "add_prefix_space": false,
+   "bos_token": "<|endoftext|>",
+   "eos_token": "<|endoftext|>",
+   "model_max_length": 1024,
+   "name_or_path": "gpt2",
+   "special_tokens_map_file": null,
+   "tokenizer_class": "GPT2Tokenizer",
+   "unk_token": "<|endoftext|>"
+ }
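This is a standard GPT2Tokenizer configuration. One detail worth noting for readers of extras/expansion.py: GPT-2's byte-level BPE represents a leading space with the 'Ġ' symbol, which is why the expansion code looks up each positive word as 'Ġ' + word in the vocabulary. A small illustrative check, assuming the files from this commit are in place:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("extras/fooocus_expansion")
print(tokenizer.tokenize(" beautiful"))          # expected: ['Ġbeautiful']
print(tokenizer.get_vocab().get('Ġbeautiful'))   # the token id that expansion.py leaves un-biased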
extras/fooocus_expansion/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
ldm_patched/contrib/external.py ADDED
@@ -0,0 +1,1954 @@
1
+ # https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
2
+
3
+ import torch
4
+
5
+ import os
6
+ import sys
7
+ import json
8
+ import hashlib
9
+ import traceback
10
+ import math
11
+ import time
12
+ import random
13
+
14
+ from PIL import Image, ImageOps, ImageSequence
15
+ from PIL.PngImagePlugin import PngInfo
16
+ import numpy as np
17
+ import safetensors.torch
18
+
19
+ pass # sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "ldm_patched"))
20
+
21
+
22
+ import ldm_patched.modules.diffusers_load
23
+ import ldm_patched.modules.samplers
24
+ import ldm_patched.modules.sample
25
+ import ldm_patched.modules.sd
26
+ import ldm_patched.modules.utils
27
+ import ldm_patched.modules.controlnet
28
+
29
+ import ldm_patched.modules.clip_vision
30
+
31
+ import ldm_patched.modules.model_management
32
+ from ldm_patched.modules.args_parser import args
33
+
34
+ import importlib
35
+
36
+ import ldm_patched.utils.path_utils
37
+ import ldm_patched.utils.latent_visualization
38
+
39
+ def before_node_execution():
40
+ ldm_patched.modules.model_management.throw_exception_if_processing_interrupted()
41
+
42
+ def interrupt_processing(value=True):
43
+ ldm_patched.modules.model_management.interrupt_current_processing(value)
44
+
45
+ MAX_RESOLUTION=8192
46
+
47
+ class CLIPTextEncode:
48
+ @classmethod
49
+ def INPUT_TYPES(s):
50
+ return {"required": {"text": ("STRING", {"multiline": True}), "clip": ("CLIP", )}}
51
+ RETURN_TYPES = ("CONDITIONING",)
52
+ FUNCTION = "encode"
53
+
54
+ CATEGORY = "conditioning"
55
+
56
+ def encode(self, clip, text):
57
+ tokens = clip.tokenize(text)
58
+ cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
59
+ return ([[cond, {"pooled_output": pooled}]], )
60
+
61
+ class ConditioningCombine:
62
+ @classmethod
63
+ def INPUT_TYPES(s):
64
+ return {"required": {"conditioning_1": ("CONDITIONING", ), "conditioning_2": ("CONDITIONING", )}}
65
+ RETURN_TYPES = ("CONDITIONING",)
66
+ FUNCTION = "combine"
67
+
68
+ CATEGORY = "conditioning"
69
+
70
+ def combine(self, conditioning_1, conditioning_2):
71
+ return (conditioning_1 + conditioning_2, )
72
+
73
+ class ConditioningAverage :
74
+ @classmethod
75
+ def INPUT_TYPES(s):
76
+ return {"required": {"conditioning_to": ("CONDITIONING", ), "conditioning_from": ("CONDITIONING", ),
77
+ "conditioning_to_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
78
+ }}
79
+ RETURN_TYPES = ("CONDITIONING",)
80
+ FUNCTION = "addWeighted"
81
+
82
+ CATEGORY = "conditioning"
83
+
84
+ def addWeighted(self, conditioning_to, conditioning_from, conditioning_to_strength):
85
+ out = []
86
+
87
+ if len(conditioning_from) > 1:
88
+ print("Warning: ConditioningAverage conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
89
+
90
+ cond_from = conditioning_from[0][0]
91
+ pooled_output_from = conditioning_from[0][1].get("pooled_output", None)
92
+
93
+ for i in range(len(conditioning_to)):
94
+ t1 = conditioning_to[i][0]
95
+ pooled_output_to = conditioning_to[i][1].get("pooled_output", pooled_output_from)
96
+ t0 = cond_from[:,:t1.shape[1]]
97
+ if t0.shape[1] < t1.shape[1]:
98
+ t0 = torch.cat([t0] + [torch.zeros((1, (t1.shape[1] - t0.shape[1]), t1.shape[2]))], dim=1)
99
+
100
+ tw = torch.mul(t1, conditioning_to_strength) + torch.mul(t0, (1.0 - conditioning_to_strength))
101
+ t_to = conditioning_to[i][1].copy()
102
+ if pooled_output_from is not None and pooled_output_to is not None:
103
+ t_to["pooled_output"] = torch.mul(pooled_output_to, conditioning_to_strength) + torch.mul(pooled_output_from, (1.0 - conditioning_to_strength))
104
+ elif pooled_output_from is not None:
105
+ t_to["pooled_output"] = pooled_output_from
106
+
107
+ n = [tw, t_to]
108
+ out.append(n)
109
+ return (out, )
110
+
111
+ class ConditioningConcat:
112
+ @classmethod
113
+ def INPUT_TYPES(s):
114
+ return {"required": {
115
+ "conditioning_to": ("CONDITIONING",),
116
+ "conditioning_from": ("CONDITIONING",),
117
+ }}
118
+ RETURN_TYPES = ("CONDITIONING",)
119
+ FUNCTION = "concat"
120
+
121
+ CATEGORY = "conditioning"
122
+
123
+ def concat(self, conditioning_to, conditioning_from):
124
+ out = []
125
+
126
+ if len(conditioning_from) > 1:
127
+ print("Warning: ConditioningConcat conditioning_from contains more than 1 cond, only the first one will actually be applied to conditioning_to.")
128
+
129
+ cond_from = conditioning_from[0][0]
130
+
131
+ for i in range(len(conditioning_to)):
132
+ t1 = conditioning_to[i][0]
133
+ tw = torch.cat((t1, cond_from),1)
134
+ n = [tw, conditioning_to[i][1].copy()]
135
+ out.append(n)
136
+
137
+ return (out, )
138
+
139
+ class ConditioningSetArea:
140
+ @classmethod
141
+ def INPUT_TYPES(s):
142
+ return {"required": {"conditioning": ("CONDITIONING", ),
143
+ "width": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
144
+ "height": ("INT", {"default": 64, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
145
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
146
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
147
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
148
+ }}
149
+ RETURN_TYPES = ("CONDITIONING",)
150
+ FUNCTION = "append"
151
+
152
+ CATEGORY = "conditioning"
153
+
154
+ def append(self, conditioning, width, height, x, y, strength):
155
+ c = []
156
+ for t in conditioning:
157
+ n = [t[0], t[1].copy()]
158
+ n[1]['area'] = (height // 8, width // 8, y // 8, x // 8)
159
+ n[1]['strength'] = strength
160
+ n[1]['set_area_to_bounds'] = False
161
+ c.append(n)
162
+ return (c, )
163
+
164
+ class ConditioningSetAreaPercentage:
165
+ @classmethod
166
+ def INPUT_TYPES(s):
167
+ return {"required": {"conditioning": ("CONDITIONING", ),
168
+ "width": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
169
+ "height": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
170
+ "x": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
171
+ "y": ("FLOAT", {"default": 0, "min": 0, "max": 1.0, "step": 0.01}),
172
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
173
+ }}
174
+ RETURN_TYPES = ("CONDITIONING",)
175
+ FUNCTION = "append"
176
+
177
+ CATEGORY = "conditioning"
178
+
179
+ def append(self, conditioning, width, height, x, y, strength):
180
+ c = []
181
+ for t in conditioning:
182
+ n = [t[0], t[1].copy()]
183
+ n[1]['area'] = ("percentage", height, width, y, x)
184
+ n[1]['strength'] = strength
185
+ n[1]['set_area_to_bounds'] = False
186
+ c.append(n)
187
+ return (c, )
188
+
189
+ class ConditioningSetMask:
190
+ @classmethod
191
+ def INPUT_TYPES(s):
192
+ return {"required": {"conditioning": ("CONDITIONING", ),
193
+ "mask": ("MASK", ),
194
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
195
+ "set_cond_area": (["default", "mask bounds"],),
196
+ }}
197
+ RETURN_TYPES = ("CONDITIONING",)
198
+ FUNCTION = "append"
199
+
200
+ CATEGORY = "conditioning"
201
+
202
+ def append(self, conditioning, mask, set_cond_area, strength):
203
+ c = []
204
+ set_area_to_bounds = False
205
+ if set_cond_area != "default":
206
+ set_area_to_bounds = True
207
+ if len(mask.shape) < 3:
208
+ mask = mask.unsqueeze(0)
209
+ for t in conditioning:
210
+ n = [t[0], t[1].copy()]
211
+ _, h, w = mask.shape
212
+ n[1]['mask'] = mask
213
+ n[1]['set_area_to_bounds'] = set_area_to_bounds
214
+ n[1]['mask_strength'] = strength
215
+ c.append(n)
216
+ return (c, )
217
+
218
+ class ConditioningZeroOut:
219
+ @classmethod
220
+ def INPUT_TYPES(s):
221
+ return {"required": {"conditioning": ("CONDITIONING", )}}
222
+ RETURN_TYPES = ("CONDITIONING",)
223
+ FUNCTION = "zero_out"
224
+
225
+ CATEGORY = "advanced/conditioning"
226
+
227
+ def zero_out(self, conditioning):
228
+ c = []
229
+ for t in conditioning:
230
+ d = t[1].copy()
231
+ if "pooled_output" in d:
232
+ d["pooled_output"] = torch.zeros_like(d["pooled_output"])
233
+ n = [torch.zeros_like(t[0]), d]
234
+ c.append(n)
235
+ return (c, )
236
+
237
+ class ConditioningSetTimestepRange:
238
+ @classmethod
239
+ def INPUT_TYPES(s):
240
+ return {"required": {"conditioning": ("CONDITIONING", ),
241
+ "start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
242
+ "end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
243
+ }}
244
+ RETURN_TYPES = ("CONDITIONING",)
245
+ FUNCTION = "set_range"
246
+
247
+ CATEGORY = "advanced/conditioning"
248
+
249
+ def set_range(self, conditioning, start, end):
250
+ c = []
251
+ for t in conditioning:
252
+ d = t[1].copy()
253
+ d['start_percent'] = start
254
+ d['end_percent'] = end
255
+ n = [t[0], d]
256
+ c.append(n)
257
+ return (c, )
258
+
259
+ class VAEDecode:
260
+ @classmethod
261
+ def INPUT_TYPES(s):
262
+ return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
263
+ RETURN_TYPES = ("IMAGE",)
264
+ FUNCTION = "decode"
265
+
266
+ CATEGORY = "latent"
267
+
268
+ def decode(self, vae, samples):
269
+ return (vae.decode(samples["samples"]), )
270
+
271
+ class VAEDecodeTiled:
272
+ @classmethod
273
+ def INPUT_TYPES(s):
274
+ return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
275
+ "tile_size": ("INT", {"default": 512, "min": 320, "max": 4096, "step": 64})
276
+ }}
277
+ RETURN_TYPES = ("IMAGE",)
278
+ FUNCTION = "decode"
279
+
280
+ CATEGORY = "_for_testing"
281
+
282
+ def decode(self, vae, samples, tile_size):
283
+ return (vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, ), )
284
+
285
+ class VAEEncode:
286
+ @classmethod
287
+ def INPUT_TYPES(s):
288
+ return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", )}}
289
+ RETURN_TYPES = ("LATENT",)
290
+ FUNCTION = "encode"
291
+
292
+ CATEGORY = "latent"
293
+
294
+ @staticmethod
295
+ def vae_encode_crop_pixels(pixels):
296
+ x = (pixels.shape[1] // 8) * 8
297
+ y = (pixels.shape[2] // 8) * 8
298
+ if pixels.shape[1] != x or pixels.shape[2] != y:
299
+ x_offset = (pixels.shape[1] % 8) // 2
300
+ y_offset = (pixels.shape[2] % 8) // 2
301
+ pixels = pixels[:, x_offset:x + x_offset, y_offset:y + y_offset, :]
302
+ return pixels
303
+
304
+ def encode(self, vae, pixels):
305
+ pixels = self.vae_encode_crop_pixels(pixels)
306
+ t = vae.encode(pixels[:,:,:,:3])
307
+ return ({"samples":t}, )
308
+
309
+ class VAEEncodeTiled:
310
+ @classmethod
311
+ def INPUT_TYPES(s):
312
+ return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ),
313
+ "tile_size": ("INT", {"default": 512, "min": 320, "max": 4096, "step": 64})
314
+ }}
315
+ RETURN_TYPES = ("LATENT",)
316
+ FUNCTION = "encode"
317
+
318
+ CATEGORY = "_for_testing"
319
+
320
+ def encode(self, vae, pixels, tile_size):
321
+ pixels = VAEEncode.vae_encode_crop_pixels(pixels)
322
+ t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, )
323
+ return ({"samples":t}, )
324
+
325
+ class VAEEncodeForInpaint:
326
+ @classmethod
327
+ def INPUT_TYPES(s):
328
+ return {"required": { "pixels": ("IMAGE", ), "vae": ("VAE", ), "mask": ("MASK", ), "grow_mask_by": ("INT", {"default": 6, "min": 0, "max": 64, "step": 1}),}}
329
+ RETURN_TYPES = ("LATENT",)
330
+ FUNCTION = "encode"
331
+
332
+ CATEGORY = "latent/inpaint"
333
+
334
+ def encode(self, vae, pixels, mask, grow_mask_by=6):
335
+ x = (pixels.shape[1] // 8) * 8
336
+ y = (pixels.shape[2] // 8) * 8
337
+ mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
338
+
339
+ pixels = pixels.clone()
340
+ if pixels.shape[1] != x or pixels.shape[2] != y:
341
+ x_offset = (pixels.shape[1] % 8) // 2
342
+ y_offset = (pixels.shape[2] % 8) // 2
343
+ pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
344
+ mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
345
+
346
+ #grow mask by a few pixels to keep things seamless in latent space
347
+ if grow_mask_by == 0:
348
+ mask_erosion = mask
349
+ else:
350
+ kernel_tensor = torch.ones((1, 1, grow_mask_by, grow_mask_by))
351
+ padding = math.ceil((grow_mask_by - 1) / 2)
352
+
353
+ mask_erosion = torch.clamp(torch.nn.functional.conv2d(mask.round(), kernel_tensor, padding=padding), 0, 1)
354
+
355
+ m = (1.0 - mask.round()).squeeze(1)
356
+ for i in range(3):
357
+ pixels[:,:,:,i] -= 0.5
358
+ pixels[:,:,:,i] *= m
359
+ pixels[:,:,:,i] += 0.5
360
+ t = vae.encode(pixels)
361
+
362
+ return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, )
363
+
364
+
365
+ class InpaintModelConditioning:
366
+ @classmethod
367
+ def INPUT_TYPES(s):
368
+ return {"required": {"positive": ("CONDITIONING", ),
369
+ "negative": ("CONDITIONING", ),
370
+ "vae": ("VAE", ),
371
+ "pixels": ("IMAGE", ),
372
+ "mask": ("MASK", ),
373
+ }}
374
+
375
+ RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
376
+ RETURN_NAMES = ("positive", "negative", "latent")
377
+ FUNCTION = "encode"
378
+
379
+ CATEGORY = "conditioning/inpaint"
380
+
381
+ def encode(self, positive, negative, pixels, vae, mask):
382
+ x = (pixels.shape[1] // 8) * 8
383
+ y = (pixels.shape[2] // 8) * 8
384
+ mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
385
+
386
+ orig_pixels = pixels
387
+ pixels = orig_pixels.clone()
388
+ if pixels.shape[1] != x or pixels.shape[2] != y:
389
+ x_offset = (pixels.shape[1] % 8) // 2
390
+ y_offset = (pixels.shape[2] % 8) // 2
391
+ pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
392
+ mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
393
+
394
+ m = (1.0 - mask.round()).squeeze(1)
395
+ for i in range(3):
396
+ pixels[:,:,:,i] -= 0.5
397
+ pixels[:,:,:,i] *= m
398
+ pixels[:,:,:,i] += 0.5
399
+ concat_latent = vae.encode(pixels)
400
+ orig_latent = vae.encode(orig_pixels)
401
+
402
+ out_latent = {}
403
+
404
+ out_latent["samples"] = orig_latent
405
+ out_latent["noise_mask"] = mask
406
+
407
+ out = []
408
+ for conditioning in [positive, negative]:
409
+ c = []
410
+ for t in conditioning:
411
+ d = t[1].copy()
412
+ d["concat_latent_image"] = concat_latent
413
+ d["concat_mask"] = mask
414
+ n = [t[0], d]
415
+ c.append(n)
416
+ out.append(c)
417
+ return (out[0], out[1], out_latent)
418
+
419
+
420
+ class SaveLatent:
421
+ def __init__(self):
422
+ self.output_dir = ldm_patched.utils.path_utils.get_output_directory()
423
+
424
+ @classmethod
425
+ def INPUT_TYPES(s):
426
+ return {"required": { "samples": ("LATENT", ),
427
+ "filename_prefix": ("STRING", {"default": "latents/ldm_patched"})},
428
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
429
+ }
430
+ RETURN_TYPES = ()
431
+ FUNCTION = "save"
432
+
433
+ OUTPUT_NODE = True
434
+
435
+ CATEGORY = "_for_testing"
436
+
437
+ def save(self, samples, filename_prefix="ldm_patched", prompt=None, extra_pnginfo=None):
438
+ full_output_folder, filename, counter, subfolder, filename_prefix = ldm_patched.utils.path_utils.get_save_image_path(filename_prefix, self.output_dir)
439
+
440
+ # support save metadata for latent sharing
441
+ prompt_info = ""
442
+ if prompt is not None:
443
+ prompt_info = json.dumps(prompt)
444
+
445
+ metadata = None
446
+ if not args.disable_server_info:
447
+ metadata = {"prompt": prompt_info}
448
+ if extra_pnginfo is not None:
449
+ for x in extra_pnginfo:
450
+ metadata[x] = json.dumps(extra_pnginfo[x])
451
+
452
+ file = f"{filename}_{counter:05}_.latent"
453
+
454
+ results = list()
455
+ results.append({
456
+ "filename": file,
457
+ "subfolder": subfolder,
458
+ "type": "output"
459
+ })
460
+
461
+ file = os.path.join(full_output_folder, file)
462
+
463
+ output = {}
464
+ output["latent_tensor"] = samples["samples"]
465
+ output["latent_format_version_0"] = torch.tensor([])
466
+
467
+ ldm_patched.modules.utils.save_torch_file(output, file, metadata=metadata)
468
+ return { "ui": { "latents": results } }
469
+
470
+
471
+ class LoadLatent:
472
+ @classmethod
473
+ def INPUT_TYPES(s):
474
+ input_dir = ldm_patched.utils.path_utils.get_input_directory()
475
+ files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
476
+ return {"required": {"latent": [sorted(files), ]}, }
477
+
478
+ CATEGORY = "_for_testing"
479
+
480
+ RETURN_TYPES = ("LATENT", )
481
+ FUNCTION = "load"
482
+
483
+ def load(self, latent):
484
+ latent_path = ldm_patched.utils.path_utils.get_annotated_filepath(latent)
485
+ latent = safetensors.torch.load_file(latent_path, device="cpu")
486
+ multiplier = 1.0
487
+ if "latent_format_version_0" not in latent:
488
+ multiplier = 1.0 / 0.18215
489
+ samples = {"samples": latent["latent_tensor"].float() * multiplier}
490
+ return (samples, )
491
+
492
+ @classmethod
493
+ def IS_CHANGED(s, latent):
494
+ image_path = ldm_patched.utils.path_utils.get_annotated_filepath(latent)
495
+ m = hashlib.sha256()
496
+ with open(image_path, 'rb') as f:
497
+ m.update(f.read())
498
+ return m.digest().hex()
499
+
500
+ @classmethod
501
+ def VALIDATE_INPUTS(s, latent):
502
+ if not ldm_patched.utils.path_utils.exists_annotated_filepath(latent):
503
+ return "Invalid latent file: {}".format(latent)
504
+ return True
505
+
506
+
507
+ class CheckpointLoader:
508
+ @classmethod
509
+ def INPUT_TYPES(s):
510
+ return {"required": { "config_name": (ldm_patched.utils.path_utils.get_filename_list("configs"), ),
511
+ "ckpt_name": (ldm_patched.utils.path_utils.get_filename_list("checkpoints"), )}}
512
+ RETURN_TYPES = ("MODEL", "CLIP", "VAE")
513
+ FUNCTION = "load_checkpoint"
514
+
515
+ CATEGORY = "advanced/loaders"
516
+
517
+ def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
518
+ config_path = ldm_patched.utils.path_utils.get_full_path("configs", config_name)
519
+ ckpt_path = ldm_patched.utils.path_utils.get_full_path("checkpoints", ckpt_name)
520
+ return ldm_patched.modules.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=ldm_patched.utils.path_utils.get_folder_paths("embeddings"))
521
+
522
+ class CheckpointLoaderSimple:
523
+ @classmethod
524
+ def INPUT_TYPES(s):
525
+ return {"required": { "ckpt_name": (ldm_patched.utils.path_utils.get_filename_list("checkpoints"), ),
526
+ }}
527
+ RETURN_TYPES = ("MODEL", "CLIP", "VAE")
528
+ FUNCTION = "load_checkpoint"
529
+
530
+ CATEGORY = "loaders"
531
+
532
+ def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
533
+ ckpt_path = ldm_patched.utils.path_utils.get_full_path("checkpoints", ckpt_name)
534
+ out = ldm_patched.modules.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=ldm_patched.utils.path_utils.get_folder_paths("embeddings"))
535
+ return out[:3]
536
+
537
+ class DiffusersLoader:
538
+ @classmethod
539
+ def INPUT_TYPES(cls):
540
+ paths = []
541
+ for search_path in ldm_patched.utils.path_utils.get_folder_paths("diffusers"):
542
+ if os.path.exists(search_path):
543
+ for root, subdir, files in os.walk(search_path, followlinks=True):
544
+ if "model_index.json" in files:
545
+ paths.append(os.path.relpath(root, start=search_path))
546
+
547
+ return {"required": {"model_path": (paths,), }}
548
+ RETURN_TYPES = ("MODEL", "CLIP", "VAE")
549
+ FUNCTION = "load_checkpoint"
550
+
551
+ CATEGORY = "advanced/loaders/deprecated"
552
+
553
+ def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
554
+ for search_path in ldm_patched.utils.path_utils.get_folder_paths("diffusers"):
555
+ if os.path.exists(search_path):
556
+ path = os.path.join(search_path, model_path)
557
+ if os.path.exists(path):
558
+ model_path = path
559
+ break
560
+
561
+ return ldm_patched.modules.diffusers_load.load_diffusers(model_path, output_vae=output_vae, output_clip=output_clip, embedding_directory=ldm_patched.utils.path_utils.get_folder_paths("embeddings"))
562
+
563
+
564
+ class unCLIPCheckpointLoader:
565
+ @classmethod
566
+ def INPUT_TYPES(s):
567
+ return {"required": { "ckpt_name": (ldm_patched.utils.path_utils.get_filename_list("checkpoints"), ),
568
+ }}
569
+ RETURN_TYPES = ("MODEL", "CLIP", "VAE", "CLIP_VISION")
570
+ FUNCTION = "load_checkpoint"
571
+
572
+ CATEGORY = "loaders"
573
+
574
+ def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
575
+ ckpt_path = ldm_patched.utils.path_utils.get_full_path("checkpoints", ckpt_name)
576
+ out = ldm_patched.modules.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=ldm_patched.utils.path_utils.get_folder_paths("embeddings"))
577
+ return out
578
+
579
+ class CLIPSetLastLayer:
580
+ @classmethod
581
+ def INPUT_TYPES(s):
582
+ return {"required": { "clip": ("CLIP", ),
583
+ "stop_at_clip_layer": ("INT", {"default": -1, "min": -24, "max": -1, "step": 1}),
584
+ }}
585
+ RETURN_TYPES = ("CLIP",)
586
+ FUNCTION = "set_last_layer"
587
+
588
+ CATEGORY = "conditioning"
589
+
590
+ def set_last_layer(self, clip, stop_at_clip_layer):
591
+ clip = clip.clone()
592
+ clip.clip_layer(stop_at_clip_layer)
593
+ return (clip,)
594
+
595
+ class LoraLoader:
596
+ def __init__(self):
597
+ self.loaded_lora = None
598
+
599
+ @classmethod
600
+ def INPUT_TYPES(s):
601
+ return {"required": { "model": ("MODEL",),
602
+ "clip": ("CLIP", ),
603
+ "lora_name": (ldm_patched.utils.path_utils.get_filename_list("loras"), ),
604
+ "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
605
+ "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
606
+ }}
607
+ RETURN_TYPES = ("MODEL", "CLIP")
608
+ FUNCTION = "load_lora"
609
+
610
+ CATEGORY = "loaders"
611
+
612
+ def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
613
+ if strength_model == 0 and strength_clip == 0:
614
+ return (model, clip)
615
+
616
+ lora_path = ldm_patched.utils.path_utils.get_full_path("loras", lora_name)
617
+ lora = None
618
+ if self.loaded_lora is not None:
619
+ if self.loaded_lora[0] == lora_path:
620
+ lora = self.loaded_lora[1]
621
+ else:
622
+ temp = self.loaded_lora
623
+ self.loaded_lora = None
624
+ del temp
625
+
626
+ if lora is None:
627
+ lora = ldm_patched.modules.utils.load_torch_file(lora_path, safe_load=True)
628
+ self.loaded_lora = (lora_path, lora)
629
+
630
+ model_lora, clip_lora = ldm_patched.modules.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
631
+ return (model_lora, clip_lora)
632
+
633
+ class LoraLoaderModelOnly(LoraLoader):
634
+ @classmethod
635
+ def INPUT_TYPES(s):
636
+ return {"required": { "model": ("MODEL",),
637
+ "lora_name": (ldm_patched.utils.path_utils.get_filename_list("loras"), ),
638
+ "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
639
+ }}
640
+ RETURN_TYPES = ("MODEL",)
641
+ FUNCTION = "load_lora_model_only"
642
+
643
+ def load_lora_model_only(self, model, lora_name, strength_model):
644
+ return (self.load_lora(model, None, lora_name, strength_model, 0)[0],)
645
+
646
+ class VAELoader:
647
+ @staticmethod
648
+ def vae_list():
649
+ vaes = ldm_patched.utils.path_utils.get_filename_list("vae")
650
+ approx_vaes = ldm_patched.utils.path_utils.get_filename_list("vae_approx")
651
+ sdxl_taesd_enc = False
652
+ sdxl_taesd_dec = False
653
+ sd1_taesd_enc = False
654
+ sd1_taesd_dec = False
655
+
656
+ for v in approx_vaes:
657
+ if v.startswith("taesd_decoder."):
658
+ sd1_taesd_dec = True
659
+ elif v.startswith("taesd_encoder."):
660
+ sd1_taesd_enc = True
661
+ elif v.startswith("taesdxl_decoder."):
662
+ sdxl_taesd_dec = True
663
+ elif v.startswith("taesdxl_encoder."):
664
+ sdxl_taesd_enc = True
665
+ if sd1_taesd_dec and sd1_taesd_enc:
666
+ vaes.append("taesd")
667
+ if sdxl_taesd_dec and sdxl_taesd_enc:
668
+ vaes.append("taesdxl")
669
+ return vaes
670
+
671
+ @staticmethod
672
+ def load_taesd(name):
673
+ sd = {}
674
+ approx_vaes = ldm_patched.utils.path_utils.get_filename_list("vae_approx")
675
+
676
+ encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
677
+ decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
678
+
679
+ enc = ldm_patched.modules.utils.load_torch_file(ldm_patched.utils.path_utils.get_full_path("vae_approx", encoder))
680
+ for k in enc:
681
+ sd["taesd_encoder.{}".format(k)] = enc[k]
682
+
683
+ dec = ldm_patched.modules.utils.load_torch_file(ldm_patched.utils.path_utils.get_full_path("vae_approx", decoder))
684
+ for k in dec:
685
+ sd["taesd_decoder.{}".format(k)] = dec[k]
686
+
687
+ if name == "taesd":
688
+ sd["vae_scale"] = torch.tensor(0.18215)
689
+ elif name == "taesdxl":
690
+ sd["vae_scale"] = torch.tensor(0.13025)
691
+ return sd
692
+
693
+ @classmethod
694
+ def INPUT_TYPES(s):
695
+ return {"required": { "vae_name": (s.vae_list(), )}}
696
+ RETURN_TYPES = ("VAE",)
697
+ FUNCTION = "load_vae"
698
+
699
+ CATEGORY = "loaders"
700
+
701
+ #TODO: scale factor?
702
+ def load_vae(self, vae_name):
703
+ if vae_name in ["taesd", "taesdxl"]:
704
+ sd = self.load_taesd(vae_name)
705
+ else:
706
+ vae_path = ldm_patched.utils.path_utils.get_full_path("vae", vae_name)
707
+ sd = ldm_patched.modules.utils.load_torch_file(vae_path)
708
+ vae = ldm_patched.modules.sd.VAE(sd=sd)
709
+ return (vae,)
710
+
711
+ class ControlNetLoader:
712
+ @classmethod
713
+ def INPUT_TYPES(s):
714
+ return {"required": { "control_net_name": (ldm_patched.utils.path_utils.get_filename_list("controlnet"), )}}
715
+
716
+ RETURN_TYPES = ("CONTROL_NET",)
717
+ FUNCTION = "load_controlnet"
718
+
719
+ CATEGORY = "loaders"
720
+
721
+ def load_controlnet(self, control_net_name):
722
+ controlnet_path = ldm_patched.utils.path_utils.get_full_path("controlnet", control_net_name)
723
+ controlnet = ldm_patched.modules.controlnet.load_controlnet(controlnet_path)
724
+ return (controlnet,)
725
+
726
+ class DiffControlNetLoader:
727
+ @classmethod
728
+ def INPUT_TYPES(s):
729
+ return {"required": { "model": ("MODEL",),
730
+ "control_net_name": (ldm_patched.utils.path_utils.get_filename_list("controlnet"), )}}
731
+
732
+ RETURN_TYPES = ("CONTROL_NET",)
733
+ FUNCTION = "load_controlnet"
734
+
735
+ CATEGORY = "loaders"
736
+
737
+ def load_controlnet(self, model, control_net_name):
738
+ controlnet_path = ldm_patched.utils.path_utils.get_full_path("controlnet", control_net_name)
739
+ controlnet = ldm_patched.modules.controlnet.load_controlnet(controlnet_path, model)
740
+ return (controlnet,)
741
+
742
+
743
+ class ControlNetApply:
744
+ @classmethod
745
+ def INPUT_TYPES(s):
746
+ return {"required": {"conditioning": ("CONDITIONING", ),
747
+ "control_net": ("CONTROL_NET", ),
748
+ "image": ("IMAGE", ),
749
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01})
750
+ }}
751
+ RETURN_TYPES = ("CONDITIONING",)
752
+ FUNCTION = "apply_controlnet"
753
+
754
+ CATEGORY = "conditioning"
755
+
756
+ def apply_controlnet(self, conditioning, control_net, image, strength):
757
+ if strength == 0:
758
+ return (conditioning, )
759
+
760
+ c = []
761
+ control_hint = image.movedim(-1,1)
762
+ for t in conditioning:
763
+ n = [t[0], t[1].copy()]
764
+ c_net = control_net.copy().set_cond_hint(control_hint, strength)
765
+ if 'control' in t[1]:
766
+ c_net.set_previous_controlnet(t[1]['control'])
767
+ n[1]['control'] = c_net
768
+ n[1]['control_apply_to_uncond'] = True
769
+ c.append(n)
770
+ return (c, )
771
+
772
+
773
+ class ControlNetApplyAdvanced:
774
+ @classmethod
775
+ def INPUT_TYPES(s):
776
+ return {"required": {"positive": ("CONDITIONING", ),
777
+ "negative": ("CONDITIONING", ),
778
+ "control_net": ("CONTROL_NET", ),
779
+ "image": ("IMAGE", ),
780
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
781
+ "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
782
+ "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
783
+ }}
784
+
785
+ RETURN_TYPES = ("CONDITIONING","CONDITIONING")
786
+ RETURN_NAMES = ("positive", "negative")
787
+ FUNCTION = "apply_controlnet"
788
+
789
+ CATEGORY = "conditioning"
790
+
791
+ def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent):
792
+ if strength == 0:
793
+ return (positive, negative)
794
+
795
+ control_hint = image.movedim(-1,1)
796
+ cnets = {}
797
+
798
+ out = []
799
+ for conditioning in [positive, negative]:
800
+ c = []
801
+ for t in conditioning:
802
+ d = t[1].copy()
803
+
804
+ prev_cnet = d.get('control', None)
805
+ if prev_cnet in cnets:
806
+ c_net = cnets[prev_cnet]
807
+ else:
808
+ c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent))
809
+ c_net.set_previous_controlnet(prev_cnet)
810
+ cnets[prev_cnet] = c_net
811
+
812
+ d['control'] = c_net
813
+ d['control_apply_to_uncond'] = False
814
+ n = [t[0], d]
815
+ c.append(n)
816
+ out.append(c)
817
+ return (out[0], out[1])
818
+
819
+
820
+ class UNETLoader:
821
+ @classmethod
822
+ def INPUT_TYPES(s):
823
+ return {"required": { "unet_name": (ldm_patched.utils.path_utils.get_filename_list("unet"), ),
824
+ }}
825
+ RETURN_TYPES = ("MODEL",)
826
+ FUNCTION = "load_unet"
827
+
828
+ CATEGORY = "advanced/loaders"
829
+
830
+ def load_unet(self, unet_name):
831
+ unet_path = ldm_patched.utils.path_utils.get_full_path("unet", unet_name)
832
+ model = ldm_patched.modules.sd.load_unet(unet_path)
833
+ return (model,)
834
+
835
+ class CLIPLoader:
836
+ @classmethod
837
+ def INPUT_TYPES(s):
838
+ return {"required": { "clip_name": (ldm_patched.utils.path_utils.get_filename_list("clip"), ),
839
+ }}
840
+ RETURN_TYPES = ("CLIP",)
841
+ FUNCTION = "load_clip"
842
+
843
+ CATEGORY = "advanced/loaders"
844
+
845
+ def load_clip(self, clip_name):
846
+ clip_path = ldm_patched.utils.path_utils.get_full_path("clip", clip_name)
847
+ clip = ldm_patched.modules.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=ldm_patched.utils.path_utils.get_folder_paths("embeddings"))
848
+ return (clip,)
849
+
850
+ class DualCLIPLoader:
851
+ @classmethod
852
+ def INPUT_TYPES(s):
853
+ return {"required": { "clip_name1": (ldm_patched.utils.path_utils.get_filename_list("clip"), ), "clip_name2": (ldm_patched.utils.path_utils.get_filename_list("clip"), ),
854
+ }}
855
+ RETURN_TYPES = ("CLIP",)
856
+ FUNCTION = "load_clip"
857
+
858
+ CATEGORY = "advanced/loaders"
859
+
860
+ def load_clip(self, clip_name1, clip_name2):
861
+ clip_path1 = ldm_patched.utils.path_utils.get_full_path("clip", clip_name1)
862
+ clip_path2 = ldm_patched.utils.path_utils.get_full_path("clip", clip_name2)
863
+ clip = ldm_patched.modules.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=ldm_patched.utils.path_utils.get_folder_paths("embeddings"))
864
+ return (clip,)
865
+
866
+ class CLIPVisionLoader:
867
+ @classmethod
868
+ def INPUT_TYPES(s):
869
+ return {"required": { "clip_name": (ldm_patched.utils.path_utils.get_filename_list("clip_vision"), ),
870
+ }}
871
+ RETURN_TYPES = ("CLIP_VISION",)
872
+ FUNCTION = "load_clip"
873
+
874
+ CATEGORY = "loaders"
875
+
876
+ def load_clip(self, clip_name):
877
+ clip_path = ldm_patched.utils.path_utils.get_full_path("clip_vision", clip_name)
878
+ clip_vision = ldm_patched.modules.clip_vision.load(clip_path)
879
+ return (clip_vision,)
880
+
881
+ class CLIPVisionEncode:
882
+ @classmethod
883
+ def INPUT_TYPES(s):
884
+ return {"required": { "clip_vision": ("CLIP_VISION",),
885
+ "image": ("IMAGE",)
886
+ }}
887
+ RETURN_TYPES = ("CLIP_VISION_OUTPUT",)
888
+ FUNCTION = "encode"
889
+
890
+ CATEGORY = "conditioning"
891
+
892
+ def encode(self, clip_vision, image):
893
+ output = clip_vision.encode_image(image)
894
+ return (output,)
895
+
896
+ class StyleModelLoader:
897
+ @classmethod
898
+ def INPUT_TYPES(s):
899
+ return {"required": { "style_model_name": (ldm_patched.utils.path_utils.get_filename_list("style_models"), )}}
900
+
901
+ RETURN_TYPES = ("STYLE_MODEL",)
902
+ FUNCTION = "load_style_model"
903
+
904
+ CATEGORY = "loaders"
905
+
906
+ def load_style_model(self, style_model_name):
907
+ style_model_path = ldm_patched.utils.path_utils.get_full_path("style_models", style_model_name)
908
+ style_model = ldm_patched.modules.sd.load_style_model(style_model_path)
909
+ return (style_model,)
910
+
911
+
912
+ class StyleModelApply:
913
+ @classmethod
914
+ def INPUT_TYPES(s):
915
+ return {"required": {"conditioning": ("CONDITIONING", ),
916
+ "style_model": ("STYLE_MODEL", ),
917
+ "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
918
+ }}
919
+ RETURN_TYPES = ("CONDITIONING",)
920
+ FUNCTION = "apply_stylemodel"
921
+
922
+ CATEGORY = "conditioning/style_model"
923
+
924
+ def apply_stylemodel(self, clip_vision_output, style_model, conditioning):
925
+ cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)
926
+ c = []
927
+ for t in conditioning:
928
+ n = [torch.cat((t[0], cond), dim=1), t[1].copy()]
929
+ c.append(n)
930
+ return (c, )
931
+
932
+ class unCLIPConditioning:
933
+ @classmethod
934
+ def INPUT_TYPES(s):
935
+ return {"required": {"conditioning": ("CONDITIONING", ),
936
+ "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
937
+ "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
938
+ "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
939
+ }}
940
+ RETURN_TYPES = ("CONDITIONING",)
941
+ FUNCTION = "apply_adm"
942
+
943
+ CATEGORY = "conditioning"
944
+
945
+ def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation):
946
+ if strength == 0:
947
+ return (conditioning, )
948
+
949
+ c = []
950
+ for t in conditioning:
951
+ o = t[1].copy()
952
+ x = {"clip_vision_output": clip_vision_output, "strength": strength, "noise_augmentation": noise_augmentation}
953
+ if "unclip_conditioning" in o:
954
+ o["unclip_conditioning"] = o["unclip_conditioning"][:] + [x]
955
+ else:
956
+ o["unclip_conditioning"] = [x]
957
+ n = [t[0], o]
958
+ c.append(n)
959
+ return (c, )
960
+
961
+ class GLIGENLoader:
962
+ @classmethod
963
+ def INPUT_TYPES(s):
964
+ return {"required": { "gligen_name": (ldm_patched.utils.path_utils.get_filename_list("gligen"), )}}
965
+
966
+ RETURN_TYPES = ("GLIGEN",)
967
+ FUNCTION = "load_gligen"
968
+
969
+ CATEGORY = "loaders"
970
+
971
+ def load_gligen(self, gligen_name):
972
+ gligen_path = ldm_patched.utils.path_utils.get_full_path("gligen", gligen_name)
973
+ gligen = ldm_patched.modules.sd.load_gligen(gligen_path)
974
+ return (gligen,)
975
+
976
+ class GLIGENTextBoxApply:
977
+ @classmethod
978
+ def INPUT_TYPES(s):
979
+ return {"required": {"conditioning_to": ("CONDITIONING", ),
980
+ "clip": ("CLIP", ),
981
+ "gligen_textbox_model": ("GLIGEN", ),
982
+ "text": ("STRING", {"multiline": True}),
983
+ "width": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
984
+ "height": ("INT", {"default": 64, "min": 8, "max": MAX_RESOLUTION, "step": 8}),
985
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
986
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
987
+ }}
988
+ RETURN_TYPES = ("CONDITIONING",)
989
+ FUNCTION = "append"
990
+
991
+ CATEGORY = "conditioning/gligen"
992
+
993
+ def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y):
994
+ c = []
995
+ cond, cond_pooled = clip.encode_from_tokens(clip.tokenize(text), return_pooled=True)
996
+ for t in conditioning_to:
997
+ n = [t[0], t[1].copy()]
998
+ position_params = [(cond_pooled, height // 8, width // 8, y // 8, x // 8)]
999
+ prev = []
1000
+ if "gligen" in n[1]:
1001
+ prev = n[1]['gligen'][2]
1002
+
1003
+ n[1]['gligen'] = ("position", gligen_textbox_model, prev + position_params)
1004
+ c.append(n)
1005
+ return (c, )
1006
+
1007
+ class EmptyLatentImage:
1008
+ def __init__(self):
1009
+ self.device = ldm_patched.modules.model_management.intermediate_device()
1010
+
1011
+ @classmethod
1012
+ def INPUT_TYPES(s):
1013
+ return {"required": { "width": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
1014
+ "height": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
1015
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
1016
+ RETURN_TYPES = ("LATENT",)
1017
+ FUNCTION = "generate"
1018
+
1019
+ CATEGORY = "latent"
1020
+
1021
+ def generate(self, width, height, batch_size=1):
1022
+ latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
1023
+ return ({"samples":latent}, )
1024
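The LATENT convention used across these nodes is a dict holding an N,C,H,W tensor with 4 channels at 1/8 of the image resolution; a quick sketch of what EmptyLatentImage produces for a 1024x768 request (shapes only, no model involved):

import torch

width, height, batch_size = 1024, 768, 2
latent = {"samples": torch.zeros([batch_size, 4, height // 8, width // 8])}
print(latent["samples"].shape)   # torch.Size([2, 4, 96, 128])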
+
1025
+
1026
+ class LatentFromBatch:
1027
+ @classmethod
1028
+ def INPUT_TYPES(s):
1029
+ return {"required": { "samples": ("LATENT",),
1030
+ "batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
1031
+ "length": ("INT", {"default": 1, "min": 1, "max": 64}),
1032
+ }}
1033
+ RETURN_TYPES = ("LATENT",)
1034
+ FUNCTION = "frombatch"
1035
+
1036
+ CATEGORY = "latent/batch"
1037
+
1038
+ def frombatch(self, samples, batch_index, length):
1039
+ s = samples.copy()
1040
+ s_in = samples["samples"]
1041
+ batch_index = min(s_in.shape[0] - 1, batch_index)
1042
+ length = min(s_in.shape[0] - batch_index, length)
1043
+ s["samples"] = s_in[batch_index:batch_index + length].clone()
1044
+ if "noise_mask" in samples:
1045
+ masks = samples["noise_mask"]
1046
+ if masks.shape[0] == 1:
1047
+ s["noise_mask"] = masks.clone()
1048
+ else:
1049
+ if masks.shape[0] < s_in.shape[0]:
1050
+ masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
1051
+ s["noise_mask"] = masks[batch_index:batch_index + length].clone()
1052
+ if "batch_index" not in s:
1053
+ s["batch_index"] = [x for x in range(batch_index, batch_index+length)]
1054
+ else:
1055
+ s["batch_index"] = samples["batch_index"][batch_index:batch_index + length]
1056
+ return (s,)
1057
+
1058
+ class RepeatLatentBatch:
1059
+ @classmethod
1060
+ def INPUT_TYPES(s):
1061
+ return {"required": { "samples": ("LATENT",),
1062
+ "amount": ("INT", {"default": 1, "min": 1, "max": 64}),
1063
+ }}
1064
+ RETURN_TYPES = ("LATENT",)
1065
+ FUNCTION = "repeat"
1066
+
1067
+ CATEGORY = "latent/batch"
1068
+
1069
+ def repeat(self, samples, amount):
1070
+ s = samples.copy()
1071
+ s_in = samples["samples"]
1072
+
1073
+ s["samples"] = s_in.repeat((amount, 1,1,1))
1074
+ if "noise_mask" in samples and samples["noise_mask"].shape[0] > 1:
1075
+ masks = samples["noise_mask"]
1076
+ if masks.shape[0] < s_in.shape[0]:
1077
+ masks = masks.repeat(math.ceil(s_in.shape[0] / masks.shape[0]), 1, 1, 1)[:s_in.shape[0]]
1078
+ s["noise_mask"] = samples["noise_mask"].repeat((amount, 1,1,1))
1079
+ if "batch_index" in s:
1080
+ offset = max(s["batch_index"]) - min(s["batch_index"]) + 1
1081
+ s["batch_index"] = s["batch_index"] + [x + (i * offset) for i in range(1, amount) for x in s["batch_index"]]
1082
+ return (s,)
1083
+
1084
+ class LatentUpscale:
1085
+ upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"]
1086
+ crop_methods = ["disabled", "center"]
1087
+
1088
+ @classmethod
1089
+ def INPUT_TYPES(s):
1090
+ return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,),
1091
+ "width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1092
+ "height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1093
+ "crop": (s.crop_methods,)}}
1094
+ RETURN_TYPES = ("LATENT",)
1095
+ FUNCTION = "upscale"
1096
+
1097
+ CATEGORY = "latent"
1098
+
1099
+ def upscale(self, samples, upscale_method, width, height, crop):
1100
+ if width == 0 and height == 0:
1101
+ s = samples
1102
+ else:
1103
+ s = samples.copy()
1104
+
1105
+ if width == 0:
1106
+ height = max(64, height)
1107
+ width = max(64, round(samples["samples"].shape[3] * height / samples["samples"].shape[2]))
1108
+ elif height == 0:
1109
+ width = max(64, width)
1110
+ height = max(64, round(samples["samples"].shape[2] * width / samples["samples"].shape[3]))
1111
+ else:
1112
+ width = max(64, width)
1113
+ height = max(64, height)
1114
+
1115
+ s["samples"] = ldm_patched.modules.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop)
1116
+ return (s,)
1117
+
1118
+ class LatentUpscaleBy:
1119
+ upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"]
1120
+
1121
+ @classmethod
1122
+ def INPUT_TYPES(s):
1123
+ return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,),
1124
+ "scale_by": ("FLOAT", {"default": 1.5, "min": 0.01, "max": 8.0, "step": 0.01}),}}
1125
+ RETURN_TYPES = ("LATENT",)
1126
+ FUNCTION = "upscale"
1127
+
1128
+ CATEGORY = "latent"
1129
+
1130
+ def upscale(self, samples, upscale_method, scale_by):
1131
+ s = samples.copy()
1132
+ width = round(samples["samples"].shape[3] * scale_by)
1133
+ height = round(samples["samples"].shape[2] * scale_by)
1134
+ s["samples"] = ldm_patched.modules.utils.common_upscale(samples["samples"], width, height, upscale_method, "disabled")
1135
+ return (s,)
1136
+
1137
+ class LatentRotate:
1138
+ @classmethod
1139
+ def INPUT_TYPES(s):
1140
+ return {"required": { "samples": ("LATENT",),
1141
+ "rotation": (["none", "90 degrees", "180 degrees", "270 degrees"],),
1142
+ }}
1143
+ RETURN_TYPES = ("LATENT",)
1144
+ FUNCTION = "rotate"
1145
+
1146
+ CATEGORY = "latent/transform"
1147
+
1148
+ def rotate(self, samples, rotation):
1149
+ s = samples.copy()
1150
+ rotate_by = 0
1151
+ if rotation.startswith("90"):
1152
+ rotate_by = 1
1153
+ elif rotation.startswith("180"):
1154
+ rotate_by = 2
1155
+ elif rotation.startswith("270"):
1156
+ rotate_by = 3
1157
+
1158
+ s["samples"] = torch.rot90(samples["samples"], k=rotate_by, dims=[3, 2])
1159
+ return (s,)
1160
+
1161
+ class LatentFlip:
1162
+ @classmethod
1163
+ def INPUT_TYPES(s):
1164
+ return {"required": { "samples": ("LATENT",),
1165
+ "flip_method": (["x-axis: vertically", "y-axis: horizontally"],),
1166
+ }}
1167
+ RETURN_TYPES = ("LATENT",)
1168
+ FUNCTION = "flip"
1169
+
1170
+ CATEGORY = "latent/transform"
1171
+
1172
+ def flip(self, samples, flip_method):
1173
+ s = samples.copy()
1174
+ if flip_method.startswith("x"):
1175
+ s["samples"] = torch.flip(samples["samples"], dims=[2])
1176
+ elif flip_method.startswith("y"):
1177
+ s["samples"] = torch.flip(samples["samples"], dims=[3])
1178
+
1179
+ return (s,)
1180
+
1181
+ class LatentComposite:
1182
+ @classmethod
1183
+ def INPUT_TYPES(s):
1184
+ return {"required": { "samples_to": ("LATENT",),
1185
+ "samples_from": ("LATENT",),
1186
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1187
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1188
+ "feather": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1189
+ }}
1190
+ RETURN_TYPES = ("LATENT",)
1191
+ FUNCTION = "composite"
1192
+
1193
+ CATEGORY = "latent"
1194
+
1195
+ def composite(self, samples_to, samples_from, x, y, composite_method="normal", feather=0):
1196
+ x = x // 8
1197
+ y = y // 8
1198
+ feather = feather // 8
1199
+ samples_out = samples_to.copy()
1200
+ s = samples_to["samples"].clone()
1201
+ samples_to = samples_to["samples"]
1202
+ samples_from = samples_from["samples"]
1203
+ if feather == 0:
1204
+ s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
1205
+ else:
1206
+ samples_from = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x]
1207
+ mask = torch.ones_like(samples_from)
1208
+ for t in range(feather):
1209
+ if y != 0:
1210
+ mask[:,:,t:1+t,:] *= ((1.0/feather) * (t + 1))
1211
+
1212
+ if y + samples_from.shape[2] < samples_to.shape[2]:
1213
+ mask[:,:,mask.shape[2] -1 -t: mask.shape[2]-t,:] *= ((1.0/feather) * (t + 1))
1214
+ if x != 0:
1215
+ mask[:,:,:,t:1+t] *= ((1.0/feather) * (t + 1))
1216
+ if x + samples_from.shape[3] < samples_to.shape[3]:
1217
+ mask[:,:,:,mask.shape[3]- 1 - t: mask.shape[3]- t] *= ((1.0/feather) * (t + 1))
1218
+ rev_mask = torch.ones_like(mask) - mask
1219
+ s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] = samples_from[:,:,:samples_to.shape[2] - y, :samples_to.shape[3] - x] * mask + s[:,:,y:y+samples_from.shape[2],x:x+samples_from.shape[3]] * rev_mask
1220
+ samples_out["samples"] = s
1221
+ return (samples_out,)
1222
+
1223
+ class LatentBlend:
1224
+ @classmethod
1225
+ def INPUT_TYPES(s):
1226
+ return {"required": {
1227
+ "samples1": ("LATENT",),
1228
+ "samples2": ("LATENT",),
1229
+ "blend_factor": ("FLOAT", {
1230
+ "default": 0.5,
1231
+ "min": 0,
1232
+ "max": 1,
1233
+ "step": 0.01
1234
+ }),
1235
+ }}
1236
+
1237
+ RETURN_TYPES = ("LATENT",)
1238
+ FUNCTION = "blend"
1239
+
1240
+ CATEGORY = "_for_testing"
1241
+
1242
+ def blend(self, samples1, samples2, blend_factor:float, blend_mode: str="normal"):
1243
+
1244
+ samples_out = samples1.copy()
1245
+ samples1 = samples1["samples"]
1246
+ samples2 = samples2["samples"]
1247
+
1248
+ if samples1.shape != samples2.shape:
1249
+ samples2.permute(0, 3, 1, 2)
1250
+ samples2 = ldm_patched.modules.utils.common_upscale(samples2, samples1.shape[3], samples1.shape[2], 'bicubic', crop='center')
1251
+ samples2.permute(0, 2, 3, 1)
1252
+
1253
+ samples_blended = self.blend_mode(samples1, samples2, blend_mode)
1254
+ samples_blended = samples1 * blend_factor + samples_blended * (1 - blend_factor)
1255
+ samples_out["samples"] = samples_blended
1256
+ return (samples_out,)
1257
+
1258
+ def blend_mode(self, img1, img2, mode):
1259
+ if mode == "normal":
1260
+ return img2
1261
+ else:
1262
+ raise ValueError(f"Unsupported blend mode: {mode}")
1263
+
1264
+ class LatentCrop:
1265
+ @classmethod
1266
+ def INPUT_TYPES(s):
1267
+ return {"required": { "samples": ("LATENT",),
1268
+ "width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
1269
+ "height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 8}),
1270
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1271
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1272
+ }}
1273
+ RETURN_TYPES = ("LATENT",)
1274
+ FUNCTION = "crop"
1275
+
1276
+ CATEGORY = "latent/transform"
1277
+
1278
+ def crop(self, samples, width, height, x, y):
1279
+ s = samples.copy()
1280
+ samples = samples['samples']
1281
+ x = x // 8
1282
+ y = y // 8
1283
+
1284
+ #enforce minimum size of 64
1285
+ if x > (samples.shape[3] - 8):
1286
+ x = samples.shape[3] - 8
1287
+ if y > (samples.shape[2] - 8):
1288
+ y = samples.shape[2] - 8
1289
+
1290
+ new_height = height // 8
1291
+ new_width = width // 8
1292
+ to_x = new_width + x
1293
+ to_y = new_height + y
1294
+ s['samples'] = samples[:,:,y:to_y, x:to_x]
1295
+ return (s,)
1296
+
1297
+ class SetLatentNoiseMask:
1298
+ @classmethod
1299
+ def INPUT_TYPES(s):
1300
+ return {"required": { "samples": ("LATENT",),
1301
+ "mask": ("MASK",),
1302
+ }}
1303
+ RETURN_TYPES = ("LATENT",)
1304
+ FUNCTION = "set_mask"
1305
+
1306
+ CATEGORY = "latent/inpaint"
1307
+
1308
+ def set_mask(self, samples, mask):
1309
+ s = samples.copy()
1310
+ s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
1311
+ return (s,)
1312
+
1313
+ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
1314
+ latent_image = latent["samples"]
1315
+ if disable_noise:
1316
+ noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
1317
+ else:
1318
+ batch_inds = latent["batch_index"] if "batch_index" in latent else None
1319
+ noise = ldm_patched.modules.sample.prepare_noise(latent_image, seed, batch_inds)
1320
+
1321
+ noise_mask = None
1322
+ if "noise_mask" in latent:
1323
+ noise_mask = latent["noise_mask"]
1324
+
1325
+ callback = ldm_patched.utils.latent_visualization.prepare_callback(model, steps)
1326
+ disable_pbar = not ldm_patched.modules.utils.PROGRESS_BAR_ENABLED
1327
+ samples = ldm_patched.modules.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
1328
+ denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
1329
+ force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
1330
+ out = latent.copy()
1331
+ out["samples"] = samples
1332
+ return (out, )
1333
+
1334
+ class KSampler:
1335
+ @classmethod
1336
+ def INPUT_TYPES(s):
1337
+ return {"required":
1338
+ {"model": ("MODEL",),
1339
+ "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
1340
+ "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
1341
+ "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
1342
+ "sampler_name": (ldm_patched.modules.samplers.KSampler.SAMPLERS, ),
1343
+ "scheduler": (ldm_patched.modules.samplers.KSampler.SCHEDULERS, ),
1344
+ "positive": ("CONDITIONING", ),
1345
+ "negative": ("CONDITIONING", ),
1346
+ "latent_image": ("LATENT", ),
1347
+ "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
1348
+ }
1349
+ }
1350
+
1351
+ RETURN_TYPES = ("LATENT",)
1352
+ FUNCTION = "sample"
1353
+
1354
+ CATEGORY = "sampling"
1355
+
1356
+ def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0):
1357
+ return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
1358
+
1359
+ class KSamplerAdvanced:
1360
+ @classmethod
1361
+ def INPUT_TYPES(s):
1362
+ return {"required":
1363
+ {"model": ("MODEL",),
1364
+ "add_noise": (["enable", "disable"], ),
1365
+ "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
1366
+ "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
1367
+ "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
1368
+ "sampler_name": (ldm_patched.modules.samplers.KSampler.SAMPLERS, ),
1369
+ "scheduler": (ldm_patched.modules.samplers.KSampler.SCHEDULERS, ),
1370
+ "positive": ("CONDITIONING", ),
1371
+ "negative": ("CONDITIONING", ),
1372
+ "latent_image": ("LATENT", ),
1373
+ "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
1374
+ "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}),
1375
+ "return_with_leftover_noise": (["disable", "enable"], ),
1376
+ }
1377
+ }
1378
+
1379
+ RETURN_TYPES = ("LATENT",)
1380
+ FUNCTION = "sample"
1381
+
1382
+ CATEGORY = "sampling"
1383
+
1384
+ def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0):
1385
+ force_full_denoise = True
1386
+ if return_with_leftover_noise == "enable":
1387
+ force_full_denoise = False
1388
+ disable_noise = False
1389
+ if add_noise == "disable":
1390
+ disable_noise = True
1391
+ return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise)
1392
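As a hedged illustration of how these arguments are typically combined (the 30-step split below is only an example, nothing the code mandates): a first pass that stops early and keeps its leftover noise, followed by a second pass that adds no fresh noise and finishes the denoise.

base_kwargs = dict(add_noise="enable", steps=30, start_at_step=0, end_at_step=24,
                   return_with_leftover_noise="enable")     # hands partially denoised latents on
refiner_kwargs = dict(add_noise="disable", steps=30, start_at_step=24, end_at_step=10000,
                      return_with_leftover_noise="disable")  # force_full_denoise=True internally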
+
1393
+ class SaveImage:
1394
+ def __init__(self):
1395
+ self.output_dir = ldm_patched.utils.path_utils.get_output_directory()
1396
+ self.type = "output"
1397
+ self.prefix_append = ""
1398
+ self.compress_level = 4
1399
+
1400
+ @classmethod
1401
+ def INPUT_TYPES(s):
1402
+ return {"required":
1403
+ {"images": ("IMAGE", ),
1404
+ "filename_prefix": ("STRING", {"default": "ldm_patched"})},
1405
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
1406
+ }
1407
+
1408
+ RETURN_TYPES = ()
1409
+ FUNCTION = "save_images"
1410
+
1411
+ OUTPUT_NODE = True
1412
+
1413
+ CATEGORY = "image"
1414
+
1415
+ def save_images(self, images, filename_prefix="ldm_patched", prompt=None, extra_pnginfo=None):
1416
+ filename_prefix += self.prefix_append
1417
+ full_output_folder, filename, counter, subfolder, filename_prefix = ldm_patched.utils.path_utils.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
1418
+ results = list()
1419
+ for image in images:
1420
+ i = 255. * image.cpu().numpy()
1421
+ img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
1422
+ metadata = None
1423
+ if not args.disable_server_info:
1424
+ metadata = PngInfo()
1425
+ if prompt is not None:
1426
+ metadata.add_text("prompt", json.dumps(prompt))
1427
+ if extra_pnginfo is not None:
1428
+ for x in extra_pnginfo:
1429
+ metadata.add_text(x, json.dumps(extra_pnginfo[x]))
1430
+
1431
+ file = f"{filename}_{counter:05}_.png"
1432
+ img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level)
1433
+ results.append({
1434
+ "filename": file,
1435
+ "subfolder": subfolder,
1436
+ "type": self.type
1437
+ })
1438
+ counter += 1
1439
+
1440
+ return { "ui": { "images": results } }
1441
+
1442
+ class PreviewImage(SaveImage):
1443
+ def __init__(self):
1444
+ self.output_dir = ldm_patched.utils.path_utils.get_temp_directory()
1445
+ self.type = "temp"
1446
+ self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstuvwxyz") for x in range(5))
1447
+ self.compress_level = 1
1448
+
1449
+ @classmethod
1450
+ def INPUT_TYPES(s):
1451
+ return {"required":
1452
+ {"images": ("IMAGE", ), },
1453
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
1454
+ }
1455
+
1456
+ class LoadImage:
1457
+ @classmethod
1458
+ def INPUT_TYPES(s):
1459
+ input_dir = ldm_patched.utils.path_utils.get_input_directory()
1460
+ files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
1461
+ return {"required":
1462
+ {"image": (sorted(files), {"image_upload": True})},
1463
+ }
1464
+
1465
+ CATEGORY = "image"
1466
+
1467
+ RETURN_TYPES = ("IMAGE", "MASK")
1468
+ FUNCTION = "load_image"
1469
+ def load_image(self, image):
1470
+ image_path = ldm_patched.utils.path_utils.get_annotated_filepath(image)
1471
+ img = Image.open(image_path)
1472
+ output_images = []
1473
+ output_masks = []
1474
+ for i in ImageSequence.Iterator(img):
1475
+ i = ImageOps.exif_transpose(i)
1476
+ if i.mode == 'I':
1477
+ i = i.point(lambda i: i * (1 / 255))
1478
+ image = i.convert("RGB")
1479
+ image = np.array(image).astype(np.float32) / 255.0
1480
+ image = torch.from_numpy(image)[None,]
1481
+ if 'A' in i.getbands():
1482
+ mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
1483
+ mask = 1. - torch.from_numpy(mask)
1484
+ else:
1485
+ mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
1486
+ output_images.append(image)
1487
+ output_masks.append(mask.unsqueeze(0))
1488
+
1489
+ if len(output_images) > 1:
1490
+ output_image = torch.cat(output_images, dim=0)
1491
+ output_mask = torch.cat(output_masks, dim=0)
1492
+ else:
1493
+ output_image = output_images[0]
1494
+ output_mask = output_masks[0]
1495
+
1496
+ return (output_image, output_mask)
1497
+
1498
+ @classmethod
1499
+ def IS_CHANGED(s, image):
1500
+ image_path = ldm_patched.utils.path_utils.get_annotated_filepath(image)
1501
+ m = hashlib.sha256()
1502
+ with open(image_path, 'rb') as f:
1503
+ m.update(f.read())
1504
+ return m.digest().hex()
1505
+
1506
+ @classmethod
1507
+ def VALIDATE_INPUTS(s, image):
1508
+ if not ldm_patched.utils.path_utils.exists_annotated_filepath(image):
1509
+ return "Invalid image file: {}".format(image)
1510
+
1511
+ return True
1512
+
1513
+ class LoadImageMask:
1514
+ _color_channels = ["alpha", "red", "green", "blue"]
1515
+ @classmethod
1516
+ def INPUT_TYPES(s):
1517
+ input_dir = ldm_patched.utils.path_utils.get_input_directory()
1518
+ files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
1519
+ return {"required":
1520
+ {"image": (sorted(files), {"image_upload": True}),
1521
+ "channel": (s._color_channels, ), }
1522
+ }
1523
+
1524
+ CATEGORY = "mask"
1525
+
1526
+ RETURN_TYPES = ("MASK",)
1527
+ FUNCTION = "load_image"
1528
+ def load_image(self, image, channel):
1529
+ image_path = ldm_patched.utils.path_utils.get_annotated_filepath(image)
1530
+ i = Image.open(image_path)
1531
+ i = ImageOps.exif_transpose(i)
1532
+ if i.getbands() != ("R", "G", "B", "A"):
1533
+ if i.mode == 'I':
1534
+ i = i.point(lambda i: i * (1 / 255))
1535
+ i = i.convert("RGBA")
1536
+ mask = None
1537
+ c = channel[0].upper()
1538
+ if c in i.getbands():
1539
+ mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0
1540
+ mask = torch.from_numpy(mask)
1541
+ if c == 'A':
1542
+ mask = 1. - mask
1543
+ else:
1544
+ mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
1545
+ return (mask.unsqueeze(0),)
1546
+
1547
+ @classmethod
1548
+ def IS_CHANGED(s, image, channel):
1549
+ image_path = ldm_patched.utils.path_utils.get_annotated_filepath(image)
1550
+ m = hashlib.sha256()
1551
+ with open(image_path, 'rb') as f:
1552
+ m.update(f.read())
1553
+ return m.digest().hex()
1554
+
1555
+ @classmethod
1556
+ def VALIDATE_INPUTS(s, image):
1557
+ if not ldm_patched.utils.path_utils.exists_annotated_filepath(image):
1558
+ return "Invalid image file: {}".format(image)
1559
+
1560
+ return True
1561
+
1562
+ class ImageScale:
1563
+ upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
1564
+ crop_methods = ["disabled", "center"]
1565
+
1566
+ @classmethod
1567
+ def INPUT_TYPES(s):
1568
+ return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
1569
+ "width": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
1570
+ "height": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
1571
+ "crop": (s.crop_methods,)}}
1572
+ RETURN_TYPES = ("IMAGE",)
1573
+ FUNCTION = "upscale"
1574
+
1575
+ CATEGORY = "image/upscaling"
1576
+
1577
+ def upscale(self, image, upscale_method, width, height, crop):
1578
+ if width == 0 and height == 0:
1579
+ s = image
1580
+ else:
1581
+ samples = image.movedim(-1,1)
1582
+
1583
+ if width == 0:
1584
+ width = max(1, round(samples.shape[3] * height / samples.shape[2]))
1585
+ elif height == 0:
1586
+ height = max(1, round(samples.shape[2] * width / samples.shape[3]))
1587
+
1588
+ s = ldm_patched.modules.utils.common_upscale(samples, width, height, upscale_method, crop)
1589
+ s = s.movedim(1,-1)
1590
+ return (s,)
1591
+
1592
+ class ImageScaleBy:
1593
+ upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
1594
+
1595
+ @classmethod
1596
+ def INPUT_TYPES(s):
1597
+ return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
1598
+ "scale_by": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 8.0, "step": 0.01}),}}
1599
+ RETURN_TYPES = ("IMAGE",)
1600
+ FUNCTION = "upscale"
1601
+
1602
+ CATEGORY = "image/upscaling"
1603
+
1604
+ def upscale(self, image, upscale_method, scale_by):
1605
+ samples = image.movedim(-1,1)
1606
+ width = round(samples.shape[3] * scale_by)
1607
+ height = round(samples.shape[2] * scale_by)
1608
+ s = ldm_patched.modules.utils.common_upscale(samples, width, height, upscale_method, "disabled")
1609
+ s = s.movedim(1,-1)
1610
+ return (s,)
1611
+
1612
+ class ImageInvert:
1613
+
1614
+ @classmethod
1615
+ def INPUT_TYPES(s):
1616
+ return {"required": { "image": ("IMAGE",)}}
1617
+
1618
+ RETURN_TYPES = ("IMAGE",)
1619
+ FUNCTION = "invert"
1620
+
1621
+ CATEGORY = "image"
1622
+
1623
+ def invert(self, image):
1624
+ s = 1.0 - image
1625
+ return (s,)
1626
+
1627
+ class ImageBatch:
1628
+
1629
+ @classmethod
1630
+ def INPUT_TYPES(s):
1631
+ return {"required": { "image1": ("IMAGE",), "image2": ("IMAGE",)}}
1632
+
1633
+ RETURN_TYPES = ("IMAGE",)
1634
+ FUNCTION = "batch"
1635
+
1636
+ CATEGORY = "image"
1637
+
1638
+ def batch(self, image1, image2):
1639
+ if image1.shape[1:] != image2.shape[1:]:
1640
+ image2 = ldm_patched.modules.utils.common_upscale(image2.movedim(-1,1), image1.shape[2], image1.shape[1], "bilinear", "center").movedim(1,-1)
1641
+ s = torch.cat((image1, image2), dim=0)
1642
+ return (s,)
1643
+
1644
+ class EmptyImage:
1645
+ def __init__(self, device="cpu"):
1646
+ self.device = device
1647
+
1648
+ @classmethod
1649
+ def INPUT_TYPES(s):
1650
+ return {"required": { "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
1651
+ "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
1652
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
1653
+ "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
1654
+ }}
1655
+ RETURN_TYPES = ("IMAGE",)
1656
+ FUNCTION = "generate"
1657
+
1658
+ CATEGORY = "image"
1659
+
1660
+ def generate(self, width, height, batch_size=1, color=0):
1661
+ r = torch.full([batch_size, height, width, 1], ((color >> 16) & 0xFF) / 0xFF)
1662
+ g = torch.full([batch_size, height, width, 1], ((color >> 8) & 0xFF) / 0xFF)
1663
+ b = torch.full([batch_size, height, width, 1], ((color) & 0xFF) / 0xFF)
1664
+ return (torch.cat((r, g, b), dim=-1), )
1665
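The packed color integer is split into per-channel floats in [0, 1]; for example:

color = 0xFF8040
r = ((color >> 16) & 0xFF) / 0xFF   # 255 / 255 = 1.0
g = ((color >> 8) & 0xFF) / 0xFF    # 128 / 255 ≈ 0.502
b = (color & 0xFF) / 0xFF           # 64 / 255 ≈ 0.251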
+
1666
+ class ImagePadForOutpaint:
1667
+
1668
+ @classmethod
1669
+ def INPUT_TYPES(s):
1670
+ return {
1671
+ "required": {
1672
+ "image": ("IMAGE",),
1673
+ "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1674
+ "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1675
+ "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1676
+ "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
1677
+ "feathering": ("INT", {"default": 40, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
1678
+ }
1679
+ }
1680
+
1681
+ RETURN_TYPES = ("IMAGE", "MASK")
1682
+ FUNCTION = "expand_image"
1683
+
1684
+ CATEGORY = "image"
1685
+
1686
+ def expand_image(self, image, left, top, right, bottom, feathering):
1687
+ d1, d2, d3, d4 = image.size()
1688
+
1689
+ new_image = torch.ones(
1690
+ (d1, d2 + top + bottom, d3 + left + right, d4),
1691
+ dtype=torch.float32,
1692
+ ) * 0.5
1693
+
1694
+ new_image[:, top:top + d2, left:left + d3, :] = image
1695
+
1696
+ mask = torch.ones(
1697
+ (d2 + top + bottom, d3 + left + right),
1698
+ dtype=torch.float32,
1699
+ )
1700
+
1701
+ t = torch.zeros(
1702
+ (d2, d3),
1703
+ dtype=torch.float32
1704
+ )
1705
+
1706
+ if feathering > 0 and feathering * 2 < d2 and feathering * 2 < d3:
1707
+
1708
+ for i in range(d2):
1709
+ for j in range(d3):
1710
+ dt = i if top != 0 else d2
1711
+ db = d2 - i if bottom != 0 else d2
1712
+
1713
+ dl = j if left != 0 else d3
1714
+ dr = d3 - j if right != 0 else d3
1715
+
1716
+ d = min(dt, db, dl, dr)
1717
+
1718
+ if d >= feathering:
1719
+ continue
1720
+
1721
+ v = (feathering - d) / feathering
1722
+
1723
+ t[i, j] = v * v
1724
+
1725
+ mask[top:top + d2, left:left + d3] = t
1726
+
1727
+ return (new_image, mask)
1728
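Worked example of the quadratic feather falloff: inside the original image the mask ramps from 0 (deep interior, kept as-is) up toward 1 next to a padded border, where d is the distance in pixels to the nearest padded edge:

feathering = 40
for d in (0, 10, 20, 39, 40):
    v = max(feathering - d, 0) / feathering
    print(d, v * v)   # 1.0, 0.5625, 0.25, 0.000625, 0.0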
+
1729
+
1730
+ NODE_CLASS_MAPPINGS = {
1731
+ "KSampler": KSampler,
1732
+ "CheckpointLoaderSimple": CheckpointLoaderSimple,
1733
+ "CLIPTextEncode": CLIPTextEncode,
1734
+ "CLIPSetLastLayer": CLIPSetLastLayer,
1735
+ "VAEDecode": VAEDecode,
1736
+ "VAEEncode": VAEEncode,
1737
+ "VAEEncodeForInpaint": VAEEncodeForInpaint,
1738
+ "VAELoader": VAELoader,
1739
+ "EmptyLatentImage": EmptyLatentImage,
1740
+ "LatentUpscale": LatentUpscale,
1741
+ "LatentUpscaleBy": LatentUpscaleBy,
1742
+ "LatentFromBatch": LatentFromBatch,
1743
+ "RepeatLatentBatch": RepeatLatentBatch,
1744
+ "SaveImage": SaveImage,
1745
+ "PreviewImage": PreviewImage,
1746
+ "LoadImage": LoadImage,
1747
+ "LoadImageMask": LoadImageMask,
1748
+ "ImageScale": ImageScale,
1749
+ "ImageScaleBy": ImageScaleBy,
1750
+ "ImageInvert": ImageInvert,
1751
+ "ImageBatch": ImageBatch,
1752
+ "ImagePadForOutpaint": ImagePadForOutpaint,
1753
+ "EmptyImage": EmptyImage,
1754
+ "ConditioningAverage": ConditioningAverage ,
1755
+ "ConditioningCombine": ConditioningCombine,
1756
+ "ConditioningConcat": ConditioningConcat,
1757
+ "ConditioningSetArea": ConditioningSetArea,
1758
+ "ConditioningSetAreaPercentage": ConditioningSetAreaPercentage,
1759
+ "ConditioningSetMask": ConditioningSetMask,
1760
+ "KSamplerAdvanced": KSamplerAdvanced,
1761
+ "SetLatentNoiseMask": SetLatentNoiseMask,
1762
+ "LatentComposite": LatentComposite,
1763
+ "LatentBlend": LatentBlend,
1764
+ "LatentRotate": LatentRotate,
1765
+ "LatentFlip": LatentFlip,
1766
+ "LatentCrop": LatentCrop,
1767
+ "LoraLoader": LoraLoader,
1768
+ "CLIPLoader": CLIPLoader,
1769
+ "UNETLoader": UNETLoader,
1770
+ "DualCLIPLoader": DualCLIPLoader,
1771
+ "CLIPVisionEncode": CLIPVisionEncode,
1772
+ "StyleModelApply": StyleModelApply,
1773
+ "unCLIPConditioning": unCLIPConditioning,
1774
+ "ControlNetApply": ControlNetApply,
1775
+ "ControlNetApplyAdvanced": ControlNetApplyAdvanced,
1776
+ "ControlNetLoader": ControlNetLoader,
1777
+ "DiffControlNetLoader": DiffControlNetLoader,
1778
+ "StyleModelLoader": StyleModelLoader,
1779
+ "CLIPVisionLoader": CLIPVisionLoader,
1780
+ "VAEDecodeTiled": VAEDecodeTiled,
1781
+ "VAEEncodeTiled": VAEEncodeTiled,
1782
+ "unCLIPCheckpointLoader": unCLIPCheckpointLoader,
1783
+ "GLIGENLoader": GLIGENLoader,
1784
+ "GLIGENTextBoxApply": GLIGENTextBoxApply,
1785
+ "InpaintModelConditioning": InpaintModelConditioning,
1786
+
1787
+ "CheckpointLoader": CheckpointLoader,
1788
+ "DiffusersLoader": DiffusersLoader,
1789
+
1790
+ "LoadLatent": LoadLatent,
1791
+ "SaveLatent": SaveLatent,
1792
+
1793
+ "ConditioningZeroOut": ConditioningZeroOut,
1794
+ "ConditioningSetTimestepRange": ConditioningSetTimestepRange,
1795
+ "LoraLoaderModelOnly": LoraLoaderModelOnly,
1796
+ }
1797
+
1798
+ NODE_DISPLAY_NAME_MAPPINGS = {
1799
+ # Sampling
1800
+ "KSampler": "KSampler",
1801
+ "KSamplerAdvanced": "KSampler (Advanced)",
1802
+ # Loaders
1803
+ "CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)",
1804
+ "CheckpointLoaderSimple": "Load Checkpoint",
1805
+ "VAELoader": "Load VAE",
1806
+ "LoraLoader": "Load LoRA",
1807
+ "CLIPLoader": "Load CLIP",
1808
+ "ControlNetLoader": "Load ControlNet Model",
1809
+ "DiffControlNetLoader": "Load ControlNet Model (diff)",
1810
+ "StyleModelLoader": "Load Style Model",
1811
+ "CLIPVisionLoader": "Load CLIP Vision",
1812
+ "UpscaleModelLoader": "Load Upscale Model",
1813
+ # Conditioning
1814
+ "CLIPVisionEncode": "CLIP Vision Encode",
1815
+ "StyleModelApply": "Apply Style Model",
1816
+ "CLIPTextEncode": "CLIP Text Encode (Prompt)",
1817
+ "CLIPSetLastLayer": "CLIP Set Last Layer",
1818
+ "ConditioningCombine": "Conditioning (Combine)",
1819
+ "ConditioningAverage ": "Conditioning (Average)",
1820
+ "ConditioningConcat": "Conditioning (Concat)",
1821
+ "ConditioningSetArea": "Conditioning (Set Area)",
1822
+ "ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
1823
+ "ConditioningSetMask": "Conditioning (Set Mask)",
1824
+ "ControlNetApply": "Apply ControlNet",
1825
+ "ControlNetApplyAdvanced": "Apply ControlNet (Advanced)",
1826
+ # Latent
1827
+ "VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
1828
+ "SetLatentNoiseMask": "Set Latent Noise Mask",
1829
+ "VAEDecode": "VAE Decode",
1830
+ "VAEEncode": "VAE Encode",
1831
+ "LatentRotate": "Rotate Latent",
1832
+ "LatentFlip": "Flip Latent",
1833
+ "LatentCrop": "Crop Latent",
1834
+ "EmptyLatentImage": "Empty Latent Image",
1835
+ "LatentUpscale": "Upscale Latent",
1836
+ "LatentUpscaleBy": "Upscale Latent By",
1837
+ "LatentComposite": "Latent Composite",
1838
+ "LatentBlend": "Latent Blend",
1839
+ "LatentFromBatch" : "Latent From Batch",
1840
+ "RepeatLatentBatch": "Repeat Latent Batch",
1841
+ # Image
1842
+ "SaveImage": "Save Image",
1843
+ "PreviewImage": "Preview Image",
1844
+ "LoadImage": "Load Image",
1845
+ "LoadImageMask": "Load Image (as Mask)",
1846
+ "ImageScale": "Upscale Image",
1847
+ "ImageScaleBy": "Upscale Image By",
1848
+ "ImageUpscaleWithModel": "Upscale Image (using Model)",
1849
+ "ImageInvert": "Invert Image",
1850
+ "ImagePadForOutpaint": "Pad Image for Outpainting",
1851
+ "ImageBatch": "Batch Images",
1852
+ # _for_testing
1853
+ "VAEDecodeTiled": "VAE Decode (Tiled)",
1854
+ "VAEEncodeTiled": "VAE Encode (Tiled)",
1855
+ }
1856
+
1857
+ EXTENSION_WEB_DIRS = {}
1858
+
1859
+ def load_custom_node(module_path, ignore=set()):
1860
+ module_name = os.path.basename(module_path)
1861
+ if os.path.isfile(module_path):
1862
+ sp = os.path.splitext(module_path)
1863
+ module_name = sp[0]
1864
+ try:
1865
+ if os.path.isfile(module_path):
1866
+ module_spec = importlib.util.spec_from_file_location(module_name, module_path)
1867
+ module_dir = os.path.split(module_path)[0]
1868
+ else:
1869
+ module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py"))
1870
+ module_dir = module_path
1871
+
1872
+ module = importlib.util.module_from_spec(module_spec)
1873
+ sys.modules[module_name] = module
1874
+ module_spec.loader.exec_module(module)
1875
+
1876
+ if hasattr(module, "WEB_DIRECTORY") and getattr(module, "WEB_DIRECTORY") is not None:
1877
+ web_dir = os.path.abspath(os.path.join(module_dir, getattr(module, "WEB_DIRECTORY")))
1878
+ if os.path.isdir(web_dir):
1879
+ EXTENSION_WEB_DIRS[module_name] = web_dir
1880
+
1881
+ if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
1882
+ for name in module.NODE_CLASS_MAPPINGS:
1883
+ if name not in ignore:
1884
+ NODE_CLASS_MAPPINGS[name] = module.NODE_CLASS_MAPPINGS[name]
1885
+ if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS") and getattr(module, "NODE_DISPLAY_NAME_MAPPINGS") is not None:
1886
+ NODE_DISPLAY_NAME_MAPPINGS.update(module.NODE_DISPLAY_NAME_MAPPINGS)
1887
+ return True
1888
+ else:
1889
+ print(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS.")
1890
+ return False
1891
+ except Exception as e:
1892
+ print(traceback.format_exc())
1893
+ print(f"Cannot import {module_path} module for custom nodes:", e)
1894
+ return False
1895
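A minimal, hypothetical custom-node module that this loader would accept: the only hard requirement it checks for is a NODE_CLASS_MAPPINGS dict, with NODE_DISPLAY_NAME_MAPPINGS and WEB_DIRECTORY picked up when present. All names below are made up for illustration.

# custom_nodes/example_brightness.py  (hypothetical file)
class ExampleBrightness:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"image": ("IMAGE",),
                             "gain": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 4.0, "step": 0.01})}}
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "apply"
    CATEGORY = "image"

    def apply(self, image, gain):
        # IMAGE tensors are N,H,W,C floats in [0, 1]
        return ((image * gain).clamp(0.0, 1.0),)

NODE_CLASS_MAPPINGS = {"ExampleBrightness": ExampleBrightness}
NODE_DISPLAY_NAME_MAPPINGS = {"ExampleBrightness": "Example Brightness"}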
+
1896
+ def load_custom_nodes():
1897
+ base_node_names = set(NODE_CLASS_MAPPINGS.keys())
1898
+ node_paths = ldm_patched.utils.path_utils.get_folder_paths("custom_nodes")
1899
+ node_import_times = []
1900
+ for custom_node_path in node_paths:
1901
+ possible_modules = os.listdir(os.path.realpath(custom_node_path))
1902
+ if "__pycache__" in possible_modules:
1903
+ possible_modules.remove("__pycache__")
1904
+
1905
+ for possible_module in possible_modules:
1906
+ module_path = os.path.join(custom_node_path, possible_module)
1907
+ if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
1908
+ if module_path.endswith(".disabled"): continue
1909
+ time_before = time.perf_counter()
1910
+ success = load_custom_node(module_path, base_node_names)
1911
+ node_import_times.append((time.perf_counter() - time_before, module_path, success))
1912
+
1913
+ if len(node_import_times) > 0:
1914
+ print("\nImport times for custom nodes:")
1915
+ for n in sorted(node_import_times):
1916
+ if n[2]:
1917
+ import_message = ""
1918
+ else:
1919
+ import_message = " (IMPORT FAILED)"
1920
+ print("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
1921
+ print()
1922
+
1923
+ def init_custom_nodes():
1924
+ extras_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "ldm_patched_extras")
1925
+ extras_files = [
1926
+ "nodes_latent.py",
1927
+ "nodes_hypernetwork.py",
1928
+ "nodes_upscale_model.py",
1929
+ "nodes_post_processing.py",
1930
+ "nodes_mask.py",
1931
+ "nodes_compositing.py",
1932
+ "nodes_rebatch.py",
1933
+ "nodes_model_merging.py",
1934
+ "nodes_tomesd.py",
1935
+ "nodes_clip_sdxl.py",
1936
+ "nodes_canny.py",
1937
+ "nodes_freelunch.py",
1938
+ "nodes_custom_sampler.py",
1939
+ "nodes_hypertile.py",
1940
+ "nodes_model_advanced.py",
1941
+ "nodes_model_downscale.py",
1942
+ "nodes_images.py",
1943
+ "nodes_video_model.py",
1944
+ "nodes_sag.py",
1945
+ "nodes_perpneg.py",
1946
+ "nodes_stable3d.py",
1947
+ "nodes_sdupscale.py",
1948
+ "nodes_photomaker.py",
1949
+ ]
1950
+
1951
+ for node_file in extras_files:
1952
+ load_custom_node(os.path.join(extras_dir, node_file))
1953
+
1954
+ load_custom_nodes()
ldm_patched/contrib/external_align_your_steps.py ADDED
@@ -0,0 +1,55 @@
1
+ # https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
2
+
3
+ #from: https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
4
+ import numpy as np
5
+ import torch
6
+
7
+ def loglinear_interp(t_steps, num_steps):
8
+ """
9
+ Performs log-linear interpolation of a given array of decreasing numbers.
10
+ """
11
+ xs = np.linspace(0, 1, len(t_steps))
12
+ ys = np.log(t_steps[::-1])
13
+
14
+ new_xs = np.linspace(0, 1, num_steps)
15
+ new_ys = np.interp(new_xs, xs, ys)
16
+
17
+ interped_ys = np.exp(new_ys)[::-1].copy()
18
+ return interped_ys
19
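Small worked example: resampling a 5-point decreasing schedule down to 3 points keeps both endpoints and interpolates the interior on a log scale.

loglinear_interp(np.array([10.0, 4.0, 2.0, 1.0, 0.5]), 3)
# -> array([10., 2., 0.5]): the log of the schedule is linearly resampled, then exponentiated back.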
+
20
+ NOISE_LEVELS = {"SD1": [14.6146412293, 6.4745760956, 3.8636745985, 2.6946151520, 1.8841921177, 1.3943805092, 0.9642583904, 0.6523686016, 0.3977456272, 0.1515232662, 0.0291671582],
21
+ "SDXL":[14.6146412293, 6.3184485287, 3.7681790315, 2.1811480769, 1.3405244945, 0.8620721141, 0.5550693289, 0.3798540708, 0.2332364134, 0.1114188177, 0.0291671582],
22
+ "SVD": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002]}
23
+
24
+ class AlignYourStepsScheduler:
25
+ @classmethod
26
+ def INPUT_TYPES(s):
27
+ return {"required":
28
+ {"model_type": (["SD1", "SDXL", "SVD"], ),
29
+ "steps": ("INT", {"default": 10, "min": 10, "max": 10000}),
30
+ "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
31
+ }
32
+ }
33
+ RETURN_TYPES = ("SIGMAS",)
34
+ CATEGORY = "sampling/custom_sampling/schedulers"
35
+
36
+ FUNCTION = "get_sigmas"
37
+
38
+ def get_sigmas(self, model_type, steps, denoise):
39
+ total_steps = steps
40
+ if denoise < 1.0:
41
+ if denoise <= 0.0:
42
+ return (torch.FloatTensor([]),)
43
+ total_steps = round(steps * denoise)
44
+
45
+ sigmas = NOISE_LEVELS[model_type][:]
46
+ if (steps + 1) != len(sigmas):
47
+ sigmas = loglinear_interp(sigmas, steps + 1)
48
+
49
+ sigmas = sigmas[-(total_steps + 1):]
50
+ sigmas[-1] = 0
51
+ return (torch.FloatTensor(sigmas), )
52
+
53
+ NODE_CLASS_MAPPINGS = {
54
+ "AlignYourStepsScheduler": AlignYourStepsScheduler,
55
+ }
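A quick sketch of calling the scheduler node directly, assuming the class above is in scope:

sched = AlignYourStepsScheduler()
(sigmas,) = sched.get_sigmas("SDXL", steps=10, denoise=1.0)
# len(sigmas) == 11: the tabulated SDXL noise levels with sigmas[-1] forced to 0.
# Any other step count is produced by log-linear interpolation of the same table.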
ldm_patched/contrib/external_canny.py ADDED
@@ -0,0 +1,301 @@
1
+ # https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
2
+
3
+ #From https://github.com/kornia/kornia
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import ldm_patched.modules.model_management
9
+
10
+ def get_canny_nms_kernel(device=None, dtype=None):
11
+ """Utility function that returns 3x3 kernels for the Canny Non-maximal suppression."""
12
+ return torch.tensor(
13
+ [
14
+ [[[0.0, 0.0, 0.0], [0.0, 1.0, -1.0], [0.0, 0.0, 0.0]]],
15
+ [[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]]],
16
+ [[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, -1.0, 0.0]]],
17
+ [[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]]],
18
+ [[[0.0, 0.0, 0.0], [-1.0, 1.0, 0.0], [0.0, 0.0, 0.0]]],
19
+ [[[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]],
20
+ [[[0.0, -1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]],
21
+ [[[0.0, 0.0, -1.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]],
22
+ ],
23
+ device=device,
24
+ dtype=dtype,
25
+ )
26
+
27
+
28
+ def get_hysteresis_kernel(device=None, dtype=None):
29
+ """Utility function that returns the 3x3 kernels for the Canny hysteresis."""
30
+ return torch.tensor(
31
+ [
32
+ [[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 0.0]]],
33
+ [[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]],
34
+ [[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]]],
35
+ [[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]],
36
+ [[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]],
37
+ [[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]],
38
+ [[[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]],
39
+ [[[0.0, 0.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]],
40
+ ],
41
+ device=device,
42
+ dtype=dtype,
43
+ )
44
+
45
+ def gaussian_blur_2d(img, kernel_size, sigma):
46
+ ksize_half = (kernel_size - 1) * 0.5
47
+
48
+ x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
49
+
50
+ pdf = torch.exp(-0.5 * (x / sigma).pow(2))
51
+
52
+ x_kernel = pdf / pdf.sum()
53
+ x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
54
+
55
+ kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
56
+ kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
57
+
58
+ padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
59
+
60
+ img = torch.nn.functional.pad(img, padding, mode="reflect")
61
+ img = torch.nn.functional.conv2d(img, kernel2d, groups=img.shape[-3])
62
+
63
+ return img
64
+
65
+ def get_sobel_kernel2d(device=None, dtype=None):
66
+ kernel_x = torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]], device=device, dtype=dtype)
67
+ kernel_y = kernel_x.transpose(0, 1)
68
+ return torch.stack([kernel_x, kernel_y])
69
+
70
+ def spatial_gradient(input, normalized: bool = True):
71
+ r"""Compute the first order image derivative in both x and y using a Sobel operator.
72
+ .. image:: _static/img/spatial_gradient.png
73
+ Args:
74
+ input: input image tensor with shape :math:`(B, C, H, W)`.
75
+ mode: derivatives modality, can be: `sobel` or `diff`.
76
+ order: the order of the derivatives.
77
+ normalized: whether the output is normalized.
78
+ Return:
79
+ the derivatives of the input feature map. with shape :math:`(B, C, 2, H, W)`.
80
+ .. note::
81
+ See a working example `here <https://kornia.readthedocs.io/en/latest/
82
+ filtering_edges.html>`__.
83
+ Examples:
84
+ >>> input = torch.rand(1, 3, 4, 4)
85
+ >>> output = spatial_gradient(input) # 1x3x2x4x4
86
+ >>> output.shape
87
+ torch.Size([1, 3, 2, 4, 4])
88
+ """
89
+ # KORNIA_CHECK_IS_TENSOR(input)
90
+ # KORNIA_CHECK_SHAPE(input, ['B', 'C', 'H', 'W'])
91
+
92
+ # allocate kernel
93
+ kernel = get_sobel_kernel2d(device=input.device, dtype=input.dtype)
94
+ if normalized:
95
+ kernel = normalize_kernel2d(kernel)
96
+
97
+ # prepare kernel
98
+ b, c, h, w = input.shape
99
+ tmp_kernel = kernel[:, None, ...]
100
+
101
+ # Pad with "replicate" for spatial dims, but with zeros for channel
102
+ spatial_pad = [kernel.size(1) // 2, kernel.size(1) // 2, kernel.size(2) // 2, kernel.size(2) // 2]
103
+ out_channels: int = 2
104
+ padded_inp = torch.nn.functional.pad(input.reshape(b * c, 1, h, w), spatial_pad, 'replicate')
105
+ out = F.conv2d(padded_inp, tmp_kernel, groups=1, padding=0, stride=1)
106
+ return out.reshape(b, c, out_channels, h, w)
107
+
108
+ def rgb_to_grayscale(image, rgb_weights = None):
109
+ r"""Convert a RGB image to grayscale version of image.
110
+
111
+ .. image:: _static/img/rgb_to_grayscale.png
112
+
113
+ The image data is assumed to be in the range of (0, 1).
114
+
115
+ Args:
116
+ image: RGB image to be converted to grayscale with shape :math:`(*,3,H,W)`.
117
+ rgb_weights: Weights that will be applied on each channel (RGB).
118
+ The sum of the weights should add up to one.
119
+ Returns:
120
+ grayscale version of the image with shape :math:`(*,1,H,W)`.
121
+
122
+ .. note::
123
+ See a working example `here <https://kornia.readthedocs.io/en/latest/
124
+ color_conversions.html>`__.
125
+
126
+ Example:
127
+ >>> input = torch.rand(2, 3, 4, 5)
128
+ >>> gray = rgb_to_grayscale(input) # 2x1x4x5
129
+ """
130
+
131
+ if len(image.shape) < 3 or image.shape[-3] != 3:
132
+ raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")
133
+
134
+ if rgb_weights is None:
135
+ # 8 bit images
136
+ if image.dtype == torch.uint8:
137
+ rgb_weights = torch.tensor([76, 150, 29], device=image.device, dtype=torch.uint8)
138
+ # floating point images
139
+ elif image.dtype in (torch.float16, torch.float32, torch.float64):
140
+ rgb_weights = torch.tensor([0.299, 0.587, 0.114], device=image.device, dtype=image.dtype)
141
+ else:
142
+ raise TypeError(f"Unknown data type: {image.dtype}")
143
+ else:
144
+ # is tensor that we make sure is in the same device/dtype
145
+ rgb_weights = rgb_weights.to(image)
146
+
147
+ # unpack the color image channels with RGB order
148
+ r: Tensor = image[..., 0:1, :, :]
149
+ g: Tensor = image[..., 1:2, :, :]
150
+ b: Tensor = image[..., 2:3, :, :]
151
+
152
+ w_r, w_g, w_b = rgb_weights.unbind()
153
+ return w_r * r + w_g * g + w_b * b
154
+
155
+ def canny(
156
+ input,
157
+ low_threshold = 0.1,
158
+ high_threshold = 0.2,
159
+ kernel_size = 5,
160
+ sigma = 1,
161
+ hysteresis = True,
162
+ eps = 1e-6,
163
+ ):
164
+ r"""Find edges of the input image and filters them using the Canny algorithm.
165
+ .. image:: _static/img/canny.png
166
+ Args:
167
+ input: input image tensor with shape :math:`(B,C,H,W)`.
168
+ low_threshold: lower threshold for the hysteresis procedure.
169
+ high_threshold: upper threshold for the hysteresis procedure.
170
+ kernel_size: the size of the kernel for the gaussian blur.
171
+ sigma: the standard deviation of the kernel for the gaussian blur.
172
+ hysteresis: if True, applies the hysteresis edge tracking.
173
+ Otherwise, the edges are divided between weak (0.5) and strong (1) edges.
174
+ eps: regularization number to avoid NaN during backprop.
175
+ Returns:
176
+ - the canny edge magnitudes map, shape of :math:`(B,1,H,W)`.
177
+ - the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`.
178
+ .. note::
179
+ See a working example `here <https://kornia.readthedocs.io/en/latest/
180
+ canny.html>`__.
181
+ Example:
182
+ >>> input = torch.rand(5, 3, 4, 4)
183
+ >>> magnitude, edges = canny(input) # 5x3x4x4
184
+ >>> magnitude.shape
185
+ torch.Size([5, 1, 4, 4])
186
+ >>> edges.shape
187
+ torch.Size([5, 1, 4, 4])
188
+ """
189
+ # KORNIA_CHECK_IS_TENSOR(input)
190
+ # KORNIA_CHECK_SHAPE(input, ['B', 'C', 'H', 'W'])
191
+ # KORNIA_CHECK(
192
+ # low_threshold <= high_threshold,
193
+ # "Invalid input thresholds. low_threshold should be smaller than the high_threshold. Got: "
194
+ # f"{low_threshold}>{high_threshold}",
195
+ # )
196
+ # KORNIA_CHECK(0 < low_threshold < 1, f'Invalid low threshold. Should be in range (0, 1). Got: {low_threshold}')
197
+ # KORNIA_CHECK(0 < high_threshold < 1, f'Invalid high threshold. Should be in range (0, 1). Got: {high_threshold}')
198
+
199
+ device = input.device
200
+ dtype = input.dtype
201
+
202
+ # To Grayscale
203
+ if input.shape[1] == 3:
204
+ input = rgb_to_grayscale(input)
205
+
206
+ # Gaussian filter
207
+ blurred: Tensor = gaussian_blur_2d(input, kernel_size, sigma)
208
+
209
+ # Compute the gradients
210
+ gradients: Tensor = spatial_gradient(blurred, normalized=False)
211
+
212
+ # Unpack the edges
213
+ gx: Tensor = gradients[:, :, 0]
214
+ gy: Tensor = gradients[:, :, 1]
215
+
216
+ # Compute gradient magnitude and angle
217
+ magnitude: Tensor = torch.sqrt(gx * gx + gy * gy + eps)
218
+ angle: Tensor = torch.atan2(gy, gx)
219
+
220
+ # Radians to Degrees
221
+ angle = 180.0 * angle / math.pi
222
+
223
+ # Round angle to the nearest 45 degree
224
+ angle = torch.round(angle / 45) * 45
225
+
226
+ # Non-maximal suppression
227
+ nms_kernels: Tensor = get_canny_nms_kernel(device, dtype)
228
+ nms_magnitude: Tensor = F.conv2d(magnitude, nms_kernels, padding=nms_kernels.shape[-1] // 2)
229
+
230
+ # Get the indices for both directions
231
+ positive_idx: Tensor = (angle / 45) % 8
232
+ positive_idx = positive_idx.long()
233
+
234
+ negative_idx: Tensor = ((angle / 45) + 4) % 8
235
+ negative_idx = negative_idx.long()
236
+
237
+ # Apply the non-maximum suppression to the different directions
238
+ channel_select_filtered_positive: Tensor = torch.gather(nms_magnitude, 1, positive_idx)
239
+ channel_select_filtered_negative: Tensor = torch.gather(nms_magnitude, 1, negative_idx)
240
+
241
+ channel_select_filtered: Tensor = torch.stack(
242
+ [channel_select_filtered_positive, channel_select_filtered_negative], 1
243
+ )
244
+
245
+ is_max: Tensor = channel_select_filtered.min(dim=1)[0] > 0.0
246
+
247
+ magnitude = magnitude * is_max
248
+
249
+ # Threshold
250
+ edges: Tensor = F.threshold(magnitude, low_threshold, 0.0)
251
+
252
+ low: Tensor = magnitude > low_threshold
253
+ high: Tensor = magnitude > high_threshold
254
+
255
+ edges = low * 0.5 + high * 0.5
256
+ edges = edges.to(dtype)
257
+
258
+ # Hysteresis
259
+ if hysteresis:
260
+ edges_old: Tensor = -torch.ones(edges.shape, device=edges.device, dtype=dtype)
261
+ hysteresis_kernels: Tensor = get_hysteresis_kernel(device, dtype)
262
+
263
+ while ((edges_old - edges).abs() != 0).any():
264
+ weak: Tensor = (edges == 0.5).float()
265
+ strong: Tensor = (edges == 1).float()
266
+
267
+ hysteresis_magnitude: Tensor = F.conv2d(
268
+ edges, hysteresis_kernels, padding=hysteresis_kernels.shape[-1] // 2
269
+ )
270
+ hysteresis_magnitude = (hysteresis_magnitude == 1).any(1, keepdim=True).to(dtype)
271
+ hysteresis_magnitude = hysteresis_magnitude * weak + strong
272
+
273
+ edges_old = edges.clone()
274
+ edges = hysteresis_magnitude + (hysteresis_magnitude == 0) * weak * 0.5
275
+
276
+ edges = hysteresis_magnitude
277
+
278
+ return magnitude, edges
279
+
280
+
281
+ class Canny:
282
+ @classmethod
283
+ def INPUT_TYPES(s):
284
+ return {"required": {"image": ("IMAGE",),
285
+ "low_threshold": ("FLOAT", {"default": 0.4, "min": 0.01, "max": 0.99, "step": 0.01}),
286
+ "high_threshold": ("FLOAT", {"default": 0.8, "min": 0.01, "max": 0.99, "step": 0.01})
287
+ }}
288
+
289
+ RETURN_TYPES = ("IMAGE",)
290
+ FUNCTION = "detect_edge"
291
+
292
+ CATEGORY = "image/preprocessors"
293
+
294
+ def detect_edge(self, image, low_threshold, high_threshold):
295
+ output = canny(image.to(ldm_patched.modules.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
296
+ img_out = output[1].to(ldm_patched.modules.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1)
297
+ return (img_out,)
298
+
299
+ NODE_CLASS_MAPPINGS = {
300
+ "Canny": Canny,
301
+ }
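A quick sanity-check sketch of the standalone canny() helper above (the node wrapper additionally moves the channel axis, because IMAGE tensors are N,H,W,C):

img = torch.rand(1, 3, 256, 256)                       # B,C,H,W in [0, 1]
magnitude, edges = canny(img, low_threshold=0.4, high_threshold=0.8)
print(magnitude.shape, edges.shape)                    # both torch.Size([1, 1, 256, 256])
# With hysteresis enabled (the default), edges ends up essentially binary;
# without it, weak edges are marked 0.5 and strong edges 1.0.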
ldm_patched/contrib/external_clip_sdxl.py ADDED
@@ -0,0 +1,58 @@
1
+ # https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
2
+
3
+ import torch
4
+ from ldm_patched.contrib.external import MAX_RESOLUTION
5
+
6
+ class CLIPTextEncodeSDXLRefiner:
7
+ @classmethod
8
+ def INPUT_TYPES(s):
9
+ return {"required": {
10
+ "ascore": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
11
+ "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
12
+ "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
13
+ "text": ("STRING", {"multiline": True}), "clip": ("CLIP", ),
14
+ }}
15
+ RETURN_TYPES = ("CONDITIONING",)
16
+ FUNCTION = "encode"
17
+
18
+ CATEGORY = "advanced/conditioning"
19
+
20
+ def encode(self, clip, ascore, width, height, text):
21
+ tokens = clip.tokenize(text)
22
+ cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
23
+ return ([[cond, {"pooled_output": pooled, "aesthetic_score": ascore, "width": width,"height": height}]], )
24
+
25
+ class CLIPTextEncodeSDXL:
26
+ @classmethod
27
+ def INPUT_TYPES(s):
28
+ return {"required": {
29
+ "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
30
+ "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
31
+ "crop_w": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}),
32
+ "crop_h": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}),
33
+ "target_width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
34
+ "target_height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
35
+ "text_g": ("STRING", {"multiline": True, "default": "CLIP_G"}), "clip": ("CLIP", ),
36
+ "text_l": ("STRING", {"multiline": True, "default": "CLIP_L"}), "clip": ("CLIP", ),
37
+ }}
38
+ RETURN_TYPES = ("CONDITIONING",)
39
+ FUNCTION = "encode"
40
+
41
+ CATEGORY = "advanced/conditioning"
42
+
43
+ def encode(self, clip, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l):
44
+ tokens = clip.tokenize(text_g)
45
+ tokens["l"] = clip.tokenize(text_l)["l"]
46
+ if len(tokens["l"]) != len(tokens["g"]):
47
+ empty = clip.tokenize("")
48
+ while len(tokens["l"]) < len(tokens["g"]):
49
+ tokens["l"] += empty["l"]
50
+ while len(tokens["l"]) > len(tokens["g"]):
51
+ tokens["g"] += empty["g"]
52
+ cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
53
+ return ([[cond, {"pooled_output": pooled, "width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}]], )
54
+
55
+ NODE_CLASS_MAPPINGS = {
56
+ "CLIPTextEncodeSDXLRefiner": CLIPTextEncodeSDXLRefiner,
57
+ "CLIPTextEncodeSDXL": CLIPTextEncodeSDXL,
58
+ }
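The only non-obvious part of CLIPTextEncodeSDXL.encode is that the "g" and "l" token lists must end up the same length, so the shorter side is padded with empty-prompt tokens. The same padding logic on plain Python lists (the nested values are placeholders, not real token ids):

    # Stand-ins for clip.tokenize(text_g)["g"] and clip.tokenize(text_l)["l"].
    tokens = {"g": [[320, 1125, 267], [4160, 269, 0]], "l": [[320, 1125, 267]]}
    empty = {"g": [[0, 0, 0]], "l": [[0, 0, 0]]}

    # Pad whichever branch is shorter, exactly like the node does.
    while len(tokens["l"]) < len(tokens["g"]):
        tokens["l"] += empty["l"]
    while len(tokens["l"]) > len(tokens["g"]):
        tokens["g"] += empty["g"]

    assert len(tokens["g"]) == len(tokens["l"])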
ldm_patched/contrib/external_compositing.py ADDED
@@ -0,0 +1,204 @@
1
+ # https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
2
+
3
+ import numpy as np
4
+ import torch
5
+ import ldm_patched.modules.utils
6
+ from enum import Enum
7
+
8
+ def resize_mask(mask, shape):
9
+ return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1)
10
+
11
+ class PorterDuffMode(Enum):
12
+ ADD = 0
13
+ CLEAR = 1
14
+ DARKEN = 2
15
+ DST = 3
16
+ DST_ATOP = 4
17
+ DST_IN = 5
18
+ DST_OUT = 6
19
+ DST_OVER = 7
20
+ LIGHTEN = 8
21
+ MULTIPLY = 9
22
+ OVERLAY = 10
23
+ SCREEN = 11
24
+ SRC = 12
25
+ SRC_ATOP = 13
26
+ SRC_IN = 14
27
+ SRC_OUT = 15
28
+ SRC_OVER = 16
29
+ XOR = 17
30
+
31
+
32
+ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_image: torch.Tensor, dst_alpha: torch.Tensor, mode: PorterDuffMode):
33
+ if mode == PorterDuffMode.ADD:
34
+ out_alpha = torch.clamp(src_alpha + dst_alpha, 0, 1)
35
+ out_image = torch.clamp(src_image + dst_image, 0, 1)
36
+ elif mode == PorterDuffMode.CLEAR:
37
+ out_alpha = torch.zeros_like(dst_alpha)
38
+ out_image = torch.zeros_like(dst_image)
39
+ elif mode == PorterDuffMode.DARKEN:
40
+ out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
41
+ out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.min(src_image, dst_image)
42
+ elif mode == PorterDuffMode.DST:
43
+ out_alpha = dst_alpha
44
+ out_image = dst_image
45
+ elif mode == PorterDuffMode.DST_ATOP:
46
+ out_alpha = src_alpha
47
+ out_image = src_alpha * dst_image + (1 - dst_alpha) * src_image
48
+ elif mode == PorterDuffMode.DST_IN:
49
+ out_alpha = src_alpha * dst_alpha
50
+ out_image = dst_image * src_alpha
51
+ elif mode == PorterDuffMode.DST_OUT:
52
+ out_alpha = (1 - src_alpha) * dst_alpha
53
+ out_image = (1 - src_alpha) * dst_image
54
+ elif mode == PorterDuffMode.DST_OVER:
55
+ out_alpha = dst_alpha + (1 - dst_alpha) * src_alpha
56
+ out_image = dst_image + (1 - dst_alpha) * src_image
57
+ elif mode == PorterDuffMode.LIGHTEN:
58
+ out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
59
+ out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.max(src_image, dst_image)
60
+ elif mode == PorterDuffMode.MULTIPLY:
61
+ out_alpha = src_alpha * dst_alpha
62
+ out_image = src_image * dst_image
63
+ elif mode == PorterDuffMode.OVERLAY:
64
+ out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
65
+ out_image = torch.where(2 * dst_image < dst_alpha, 2 * src_image * dst_image,
66
+ src_alpha * dst_alpha - 2 * (dst_alpha - src_image) * (src_alpha - dst_image))
67
+ elif mode == PorterDuffMode.SCREEN:
68
+ out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
69
+ out_image = src_image + dst_image - src_image * dst_image
70
+ elif mode == PorterDuffMode.SRC:
71
+ out_alpha = src_alpha
72
+ out_image = src_image
73
+ elif mode == PorterDuffMode.SRC_ATOP:
74
+ out_alpha = dst_alpha
75
+ out_image = dst_alpha * src_image + (1 - src_alpha) * dst_image
76
+ elif mode == PorterDuffMode.SRC_IN:
77
+ out_alpha = src_alpha * dst_alpha
78
+ out_image = src_image * dst_alpha
79
+ elif mode == PorterDuffMode.SRC_OUT:
80
+ out_alpha = (1 - dst_alpha) * src_alpha
81
+ out_image = (1 - dst_alpha) * src_image
82
+ elif mode == PorterDuffMode.SRC_OVER:
83
+ out_alpha = src_alpha + (1 - src_alpha) * dst_alpha
84
+ out_image = src_image + (1 - src_alpha) * dst_image
85
+ elif mode == PorterDuffMode.XOR:
86
+ out_alpha = (1 - dst_alpha) * src_alpha + (1 - src_alpha) * dst_alpha
87
+ out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image
88
+ else:
89
+ out_alpha = None
90
+ out_image = None
91
+ return out_image, out_alpha
92
+
93
+
94
+ class PorterDuffImageComposite:
95
+ @classmethod
96
+ def INPUT_TYPES(s):
97
+ return {
98
+ "required": {
99
+ "source": ("IMAGE",),
100
+ "source_alpha": ("MASK",),
101
+ "destination": ("IMAGE",),
102
+ "destination_alpha": ("MASK",),
103
+ "mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}),
104
+ },
105
+ }
106
+
107
+ RETURN_TYPES = ("IMAGE", "MASK")
108
+ FUNCTION = "composite"
109
+ CATEGORY = "mask/compositing"
110
+
111
+ def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode):
112
+ batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha))
113
+ out_images = []
114
+ out_alphas = []
115
+
116
+ for i in range(batch_size):
117
+ src_image = source[i]
118
+ dst_image = destination[i]
119
+
120
+ assert src_image.shape[2] == dst_image.shape[2] # inputs need to have same number of channels
121
+
122
+ src_alpha = source_alpha[i].unsqueeze(2)
123
+ dst_alpha = destination_alpha[i].unsqueeze(2)
124
+
125
+ if dst_alpha.shape[:2] != dst_image.shape[:2]:
126
+ upscale_input = dst_alpha.unsqueeze(0).permute(0, 3, 1, 2)
127
+ upscale_output = ldm_patched.modules.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center')
128
+ dst_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0)
129
+ if src_image.shape != dst_image.shape:
130
+ upscale_input = src_image.unsqueeze(0).permute(0, 3, 1, 2)
131
+ upscale_output = ldm_patched.modules.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center')
132
+ src_image = upscale_output.permute(0, 2, 3, 1).squeeze(0)
133
+ if src_alpha.shape != dst_alpha.shape:
134
+ upscale_input = src_alpha.unsqueeze(0).permute(0, 3, 1, 2)
135
+ upscale_output = ldm_patched.modules.utils.common_upscale(upscale_input, dst_alpha.shape[1], dst_alpha.shape[0], upscale_method='bicubic', crop='center')
136
+ src_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0)
137
+
138
+ out_image, out_alpha = porter_duff_composite(src_image, src_alpha, dst_image, dst_alpha, PorterDuffMode[mode])
139
+
140
+ out_images.append(out_image)
141
+ out_alphas.append(out_alpha.squeeze(2))
142
+
143
+ result = (torch.stack(out_images), torch.stack(out_alphas))
144
+ return result
145
+
146
+
147
+ class SplitImageWithAlpha:
148
+ @classmethod
149
+ def INPUT_TYPES(s):
150
+ return {
151
+ "required": {
152
+ "image": ("IMAGE",),
153
+ }
154
+ }
155
+
156
+ CATEGORY = "mask/compositing"
157
+ RETURN_TYPES = ("IMAGE", "MASK")
158
+ FUNCTION = "split_image_with_alpha"
159
+
160
+ def split_image_with_alpha(self, image: torch.Tensor):
161
+ out_images = [i[:,:,:3] for i in image]
162
+ out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image]
163
+ result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas))
164
+ return result
165
+
166
+
167
+ class JoinImageWithAlpha:
168
+ @classmethod
169
+ def INPUT_TYPES(s):
170
+ return {
171
+ "required": {
172
+ "image": ("IMAGE",),
173
+ "alpha": ("MASK",),
174
+ }
175
+ }
176
+
177
+ CATEGORY = "mask/compositing"
178
+ RETURN_TYPES = ("IMAGE",)
179
+ FUNCTION = "join_image_with_alpha"
180
+
181
+ def join_image_with_alpha(self, image: torch.Tensor, alpha: torch.Tensor):
182
+ batch_size = min(len(image), len(alpha))
183
+ out_images = []
184
+
185
+ alpha = 1.0 - resize_mask(alpha, image.shape[1:])
186
+ for i in range(batch_size):
187
+ out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
188
+
189
+ result = (torch.stack(out_images),)
190
+ return result
191
+
192
+
193
+ NODE_CLASS_MAPPINGS = {
194
+ "PorterDuffImageComposite": PorterDuffImageComposite,
195
+ "SplitImageWithAlpha": SplitImageWithAlpha,
196
+ "JoinImageWithAlpha": JoinImageWithAlpha,
197
+ }
198
+
199
+
200
+ NODE_DISPLAY_NAME_MAPPINGS = {
201
+ "PorterDuffImageComposite": "Porter-Duff Image Composite",
202
+ "SplitImageWithAlpha": "Split Image with Alpha",
203
+ "JoinImageWithAlpha": "Join Image with Alpha",
204
+ }
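Each branch of porter_duff_composite is the standard premultiplied-alpha Porter-Duff equation for that mode; SRC_OVER, for instance, is out = src + (1 - src_alpha) * dst for both colour and alpha. A quick numeric check of that branch with made-up pixel values:

    import torch

    # A half-opaque red source over an opaque green destination (premultiplied).
    src_image, src_alpha = torch.tensor([0.5, 0.0, 0.0]), torch.tensor([0.5])
    dst_image, dst_alpha = torch.tensor([0.0, 1.0, 0.0]), torch.tensor([1.0])

    out_image = src_image + (1 - src_alpha) * dst_image   # SRC_OVER colour
    out_alpha = src_alpha + (1 - src_alpha) * dst_alpha   # SRC_OVER alpha
    print(out_image, out_alpha)  # tensor([0.5000, 0.5000, 0.0000]) tensor([1.])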
ldm_patched/contrib/external_custom_sampler.py ADDED
@@ -0,0 +1,316 @@
1
+ # https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
2
+
3
+ import ldm_patched.modules.samplers
4
+ import ldm_patched.modules.sample
5
+ from ldm_patched.k_diffusion import sampling as k_diffusion_sampling
6
+ import ldm_patched.utils.latent_visualization
7
+ import torch
8
+ import ldm_patched.modules.utils
9
+
10
+
11
+ class BasicScheduler:
12
+ @classmethod
13
+ def INPUT_TYPES(s):
14
+ return {"required":
15
+ {"model": ("MODEL",),
16
+ "scheduler": (ldm_patched.modules.samplers.SCHEDULER_NAMES, ),
17
+ "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
18
+ "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
19
+ }
20
+ }
21
+ RETURN_TYPES = ("SIGMAS",)
22
+ CATEGORY = "sampling/custom_sampling/schedulers"
23
+
24
+ FUNCTION = "get_sigmas"
25
+
26
+ def get_sigmas(self, model, scheduler, steps, denoise):
27
+ total_steps = steps
28
+ if denoise < 1.0:
29
+ total_steps = int(steps/denoise)
30
+
31
+ ldm_patched.modules.model_management.load_models_gpu([model])
32
+ sigmas = ldm_patched.modules.samplers.calculate_sigmas_scheduler(model.model, scheduler, total_steps).cpu()
33
+ sigmas = sigmas[-(steps + 1):]
34
+ return (sigmas, )
35
+
36
+
37
+ class KarrasScheduler:
38
+ @classmethod
39
+ def INPUT_TYPES(s):
40
+ return {"required":
41
+ {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
42
+ "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
43
+ "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
44
+ "rho": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
45
+ }
46
+ }
47
+ RETURN_TYPES = ("SIGMAS",)
48
+ CATEGORY = "sampling/custom_sampling/schedulers"
49
+
50
+ FUNCTION = "get_sigmas"
51
+
52
+ def get_sigmas(self, steps, sigma_max, sigma_min, rho):
53
+ sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
54
+ return (sigmas, )
55
+
56
+ class ExponentialScheduler:
57
+ @classmethod
58
+ def INPUT_TYPES(s):
59
+ return {"required":
60
+ {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
61
+ "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
62
+ "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
63
+ }
64
+ }
65
+ RETURN_TYPES = ("SIGMAS",)
66
+ CATEGORY = "sampling/custom_sampling/schedulers"
67
+
68
+ FUNCTION = "get_sigmas"
69
+
70
+ def get_sigmas(self, steps, sigma_max, sigma_min):
71
+ sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max)
72
+ return (sigmas, )
73
+
74
+ class PolyexponentialScheduler:
75
+ @classmethod
76
+ def INPUT_TYPES(s):
77
+ return {"required":
78
+ {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
79
+ "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
80
+ "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
81
+ "rho": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
82
+ }
83
+ }
84
+ RETURN_TYPES = ("SIGMAS",)
85
+ CATEGORY = "sampling/custom_sampling/schedulers"
86
+
87
+ FUNCTION = "get_sigmas"
88
+
89
+ def get_sigmas(self, steps, sigma_max, sigma_min, rho):
90
+ sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
91
+ return (sigmas, )
92
+
93
+ class SDTurboScheduler:
94
+ @classmethod
95
+ def INPUT_TYPES(s):
96
+ return {"required":
97
+ {"model": ("MODEL",),
98
+ "steps": ("INT", {"default": 1, "min": 1, "max": 10}),
99
+ "denoise": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
100
+ }
101
+ }
102
+ RETURN_TYPES = ("SIGMAS",)
103
+ CATEGORY = "sampling/custom_sampling/schedulers"
104
+
105
+ FUNCTION = "get_sigmas"
106
+
107
+ def get_sigmas(self, model, steps, denoise):
108
+ start_step = 10 - int(10 * denoise)
109
+ timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps]
110
+ sigmas = model.model_sampling.sigma(timesteps)
111
+ sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
112
+ return (sigmas, )
113
+
114
+ class VPScheduler:
115
+ @classmethod
116
+ def INPUT_TYPES(s):
117
+ return {"required":
118
+ {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
119
+ "beta_d": ("FLOAT", {"default": 19.9, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}), #TODO: fix default values
120
+ "beta_min": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1000.0, "step":0.01, "round": False}),
121
+ "eps_s": ("FLOAT", {"default": 0.001, "min": 0.0, "max": 1.0, "step":0.0001, "round": False}),
122
+ }
123
+ }
124
+ RETURN_TYPES = ("SIGMAS",)
125
+ CATEGORY = "sampling/custom_sampling/schedulers"
126
+
127
+ FUNCTION = "get_sigmas"
128
+
129
+ def get_sigmas(self, steps, beta_d, beta_min, eps_s):
130
+ sigmas = k_diffusion_sampling.get_sigmas_vp(n=steps, beta_d=beta_d, beta_min=beta_min, eps_s=eps_s)
131
+ return (sigmas, )
132
+
133
+ class SplitSigmas:
134
+ @classmethod
135
+ def INPUT_TYPES(s):
136
+ return {"required":
137
+ {"sigmas": ("SIGMAS", ),
138
+ "step": ("INT", {"default": 0, "min": 0, "max": 10000}),
139
+ }
140
+ }
141
+ RETURN_TYPES = ("SIGMAS","SIGMAS")
142
+ CATEGORY = "sampling/custom_sampling/sigmas"
143
+
144
+ FUNCTION = "get_sigmas"
145
+
146
+ def get_sigmas(self, sigmas, step):
147
+ sigmas1 = sigmas[:step + 1]
148
+ sigmas2 = sigmas[step:]
149
+ return (sigmas1, sigmas2)
150
+
151
+ class FlipSigmas:
152
+ @classmethod
153
+ def INPUT_TYPES(s):
154
+ return {"required":
155
+ {"sigmas": ("SIGMAS", ),
156
+ }
157
+ }
158
+ RETURN_TYPES = ("SIGMAS",)
159
+ CATEGORY = "sampling/custom_sampling/sigmas"
160
+
161
+ FUNCTION = "get_sigmas"
162
+
163
+ def get_sigmas(self, sigmas):
164
+ sigmas = sigmas.flip(0)
165
+ if sigmas[0] == 0:
166
+ sigmas[0] = 0.0001
167
+ return (sigmas,)
168
+
169
+ class KSamplerSelect:
170
+ @classmethod
171
+ def INPUT_TYPES(s):
172
+ return {"required":
173
+ {"sampler_name": (ldm_patched.modules.samplers.SAMPLER_NAMES, ),
174
+ }
175
+ }
176
+ RETURN_TYPES = ("SAMPLER",)
177
+ CATEGORY = "sampling/custom_sampling/samplers"
178
+
179
+ FUNCTION = "get_sampler"
180
+
181
+ def get_sampler(self, sampler_name):
182
+ sampler = ldm_patched.modules.samplers.sampler_object(sampler_name)
183
+ return (sampler, )
184
+
185
+ class SamplerDPMPP_2M_SDE:
186
+ @classmethod
187
+ def INPUT_TYPES(s):
188
+ return {"required":
189
+ {"solver_type": (['midpoint', 'heun'], ),
190
+ "eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
191
+ "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
192
+ "noise_device": (['gpu', 'cpu'], ),
193
+ }
194
+ }
195
+ RETURN_TYPES = ("SAMPLER",)
196
+ CATEGORY = "sampling/custom_sampling/samplers"
197
+
198
+ FUNCTION = "get_sampler"
199
+
200
+ def get_sampler(self, solver_type, eta, s_noise, noise_device):
201
+ if noise_device == 'cpu':
202
+ sampler_name = "dpmpp_2m_sde"
203
+ else:
204
+ sampler_name = "dpmpp_2m_sde_gpu"
205
+ sampler = ldm_patched.modules.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})
206
+ return (sampler, )
207
+
208
+
209
+ class SamplerDPMPP_SDE:
210
+ @classmethod
211
+ def INPUT_TYPES(s):
212
+ return {"required":
213
+ {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
214
+ "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
215
+ "r": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
216
+ "noise_device": (['gpu', 'cpu'], ),
217
+ }
218
+ }
219
+ RETURN_TYPES = ("SAMPLER",)
220
+ CATEGORY = "sampling/custom_sampling/samplers"
221
+
222
+ FUNCTION = "get_sampler"
223
+
224
+ def get_sampler(self, eta, s_noise, r, noise_device):
225
+ if noise_device == 'cpu':
226
+ sampler_name = "dpmpp_sde"
227
+ else:
228
+ sampler_name = "dpmpp_sde_gpu"
229
+ sampler = ldm_patched.modules.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
230
+ return (sampler, )
231
+
232
+
233
+ class SamplerTCD:
234
+ @classmethod
235
+ def INPUT_TYPES(s):
236
+ return {
237
+ "required": {
238
+ "eta": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}),
239
+ }
240
+ }
241
+ RETURN_TYPES = ("SAMPLER",)
242
+ CATEGORY = "sampling/custom_sampling/samplers"
243
+
244
+ FUNCTION = "get_sampler"
245
+
246
+ def get_sampler(self, eta=0.3):
247
+ sampler = ldm_patched.modules.samplers.ksampler("tcd", {"eta": eta})
248
+ return (sampler, )
249
+
250
+
251
+ class SamplerCustom:
252
+ @classmethod
253
+ def INPUT_TYPES(s):
254
+ return {"required":
255
+ {"model": ("MODEL",),
256
+ "add_noise": ("BOOLEAN", {"default": True}),
257
+ "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
258
+ "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
259
+ "positive": ("CONDITIONING", ),
260
+ "negative": ("CONDITIONING", ),
261
+ "sampler": ("SAMPLER", ),
262
+ "sigmas": ("SIGMAS", ),
263
+ "latent_image": ("LATENT", ),
264
+ }
265
+ }
266
+
267
+ RETURN_TYPES = ("LATENT","LATENT")
268
+ RETURN_NAMES = ("output", "denoised_output")
269
+
270
+ FUNCTION = "sample"
271
+
272
+ CATEGORY = "sampling/custom_sampling"
273
+
274
+ def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image):
275
+ latent = latent_image
276
+ latent_image = latent["samples"]
277
+ if not add_noise:
278
+ noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
279
+ else:
280
+ batch_inds = latent["batch_index"] if "batch_index" in latent else None
281
+ noise = ldm_patched.modules.sample.prepare_noise(latent_image, noise_seed, batch_inds)
282
+
283
+ noise_mask = None
284
+ if "noise_mask" in latent:
285
+ noise_mask = latent["noise_mask"]
286
+
287
+ x0_output = {}
288
+ callback = ldm_patched.utils.latent_visualization.prepare_callback(model, sigmas.shape[-1] - 1, x0_output)
289
+
290
+ disable_pbar = not ldm_patched.modules.utils.PROGRESS_BAR_ENABLED
291
+ samples = ldm_patched.modules.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
292
+
293
+ out = latent.copy()
294
+ out["samples"] = samples
295
+ if "x0" in x0_output:
296
+ out_denoised = latent.copy()
297
+ out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu())
298
+ else:
299
+ out_denoised = out
300
+ return (out, out_denoised)
301
+
302
+ NODE_CLASS_MAPPINGS = {
303
+ "SamplerCustom": SamplerCustom,
304
+ "BasicScheduler": BasicScheduler,
305
+ "KarrasScheduler": KarrasScheduler,
306
+ "ExponentialScheduler": ExponentialScheduler,
307
+ "PolyexponentialScheduler": PolyexponentialScheduler,
308
+ "VPScheduler": VPScheduler,
309
+ "SDTurboScheduler": SDTurboScheduler,
310
+ "KSamplerSelect": KSamplerSelect,
311
+ "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE,
312
+ "SamplerDPMPP_SDE": SamplerDPMPP_SDE,
313
+ "SamplerTCD": SamplerTCD,
314
+ "SplitSigmas": SplitSigmas,
315
+ "FlipSigmas": FlipSigmas,
316
+ }
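The scheduler nodes above only produce a SIGMAS tensor, which SamplerCustom then consumes together with a SAMPLER object. For reference, get_sigmas_karras spaces n values between sigma_max and sigma_min in sigma**(1/rho) space and appends a trailing zero; a standalone sketch of that ramp with the node's default values (the real node simply calls k_diffusion_sampling.get_sigmas_karras):

    import torch

    def karras_sigmas(n, sigma_min=0.0291675, sigma_max=14.614642, rho=7.0):
        # Evenly spaced ramp in sigma**(1/rho), raised back to the rho power.
        ramp = torch.linspace(0, 1, n)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return torch.cat([sigmas, sigmas.new_zeros([1])])  # trailing 0.0, like k-diffusion

    print(karras_sigmas(10))  # 11 values, from ~14.61 down to 0.0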
ldm_patched/contrib/external_freelunch.py ADDED
@@ -0,0 +1,115 @@
1
+ # https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
2
+
3
+ #code originally taken from: https://github.com/ChenyangSi/FreeU (under MIT License)
4
+
5
+ import torch
6
+
7
+
8
+ def Fourier_filter(x, threshold, scale):
9
+ # FFT
10
+ x_freq = torch.fft.fftn(x.float(), dim=(-2, -1))
11
+ x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1))
12
+
13
+ B, C, H, W = x_freq.shape
14
+ mask = torch.ones((B, C, H, W), device=x.device)
15
+
16
+ crow, ccol = H // 2, W //2
17
+ mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
18
+ x_freq = x_freq * mask
19
+
20
+ # IFFT
21
+ x_freq = torch.fft.ifftshift(x_freq, dim=(-2, -1))
22
+ x_filtered = torch.fft.ifftn(x_freq, dim=(-2, -1)).real
23
+
24
+ return x_filtered.to(x.dtype)
25
+
26
+
27
+ class FreeU:
28
+ @classmethod
29
+ def INPUT_TYPES(s):
30
+ return {"required": { "model": ("MODEL",),
31
+ "b1": ("FLOAT", {"default": 1.1, "min": 0.0, "max": 10.0, "step": 0.01}),
32
+ "b2": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0, "step": 0.01}),
33
+ "s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}),
34
+ "s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}),
35
+ }}
36
+ RETURN_TYPES = ("MODEL",)
37
+ FUNCTION = "patch"
38
+
39
+ CATEGORY = "model_patches"
40
+
41
+ def patch(self, model, b1, b2, s1, s2):
42
+ model_channels = model.model.model_config.unet_config["model_channels"]
43
+ scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
44
+ on_cpu_devices = {}
45
+
46
+ def output_block_patch(h, hsp, transformer_options):
47
+ scale = scale_dict.get(h.shape[1], None)
48
+ if scale is not None:
49
+ h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * scale[0]
50
+ if hsp.device not in on_cpu_devices:
51
+ try:
52
+ hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
53
+ except:
54
+ print("Device", hsp.device, "does not support the torch.fft functions used in the FreeU node, switching to CPU.")
55
+ on_cpu_devices[hsp.device] = True
56
+ hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
57
+ else:
58
+ hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
59
+
60
+ return h, hsp
61
+
62
+ m = model.clone()
63
+ m.set_model_output_block_patch(output_block_patch)
64
+ return (m, )
65
+
66
+ class FreeU_V2:
67
+ @classmethod
68
+ def INPUT_TYPES(s):
69
+ return {"required": { "model": ("MODEL",),
70
+ "b1": ("FLOAT", {"default": 1.3, "min": 0.0, "max": 10.0, "step": 0.01}),
71
+ "b2": ("FLOAT", {"default": 1.4, "min": 0.0, "max": 10.0, "step": 0.01}),
72
+ "s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}),
73
+ "s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}),
74
+ }}
75
+ RETURN_TYPES = ("MODEL",)
76
+ FUNCTION = "patch"
77
+
78
+ CATEGORY = "model_patches"
79
+
80
+ def patch(self, model, b1, b2, s1, s2):
81
+ model_channels = model.model.model_config.unet_config["model_channels"]
82
+ scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
83
+ on_cpu_devices = {}
84
+
85
+ def output_block_patch(h, hsp, transformer_options):
86
+ scale = scale_dict.get(h.shape[1], None)
87
+ if scale is not None:
88
+ hidden_mean = h.mean(1).unsqueeze(1)
89
+ B = hidden_mean.shape[0]
90
+ hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
91
+ hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
92
+ hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
93
+
94
+ h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * ((scale[0] - 1 ) * hidden_mean + 1)
95
+
96
+ if hsp.device not in on_cpu_devices:
97
+ try:
98
+ hsp = Fourier_filter(hsp, threshold=1, scale=scale[1])
99
+ except:
100
+ print("Device", hsp.device, "does not support the torch.fft functions used in the FreeU node, switching to CPU.")
101
+ on_cpu_devices[hsp.device] = True
102
+ hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
103
+ else:
104
+ hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device)
105
+
106
+ return h, hsp
107
+
108
+ m = model.clone()
109
+ m.set_model_output_block_patch(output_block_patch)
110
+ return (m, )
111
+
112
+ NODE_CLASS_MAPPINGS = {
113
+ "FreeU": FreeU,
114
+ "FreeU_V2": FreeU_V2,
115
+ }
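Fourier_filter only touches the lowest spatial frequencies: after the fftshift, a (2*threshold) x (2*threshold) box around the centre of the spectrum is multiplied by scale and everything else passes through unchanged. A self-contained copy for experimenting outside the node (scale=1.0 should be a no-op up to float error):

    import torch

    def fourier_filter(x, threshold, scale):
        # Shift the spectrum, damp the low-frequency box, shift back.
        x_freq = torch.fft.fftshift(torch.fft.fftn(x.float(), dim=(-2, -1)), dim=(-2, -1))
        B, C, H, W = x_freq.shape
        mask = torch.ones((B, C, H, W))
        crow, ccol = H // 2, W // 2
        mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
        x_freq = torch.fft.ifftshift(x_freq * mask, dim=(-2, -1))
        return torch.fft.ifftn(x_freq, dim=(-2, -1)).real.to(x.dtype)

    x = torch.rand(1, 4, 16, 16)
    assert torch.allclose(fourier_filter(x, 1, 1.0), x, atol=1e-5)
    print(fourier_filter(x, 1, 0.5).mean())  # low-frequency content damped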
ldm_patched/contrib/external_hypernetwork.py ADDED
@@ -0,0 +1,121 @@
1
+ # https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
2
+
3
+ import ldm_patched.modules.utils
4
+ import ldm_patched.utils.path_utils
5
+ import torch
6
+
7
+ def load_hypernetwork_patch(path, strength):
8
+ sd = ldm_patched.modules.utils.load_torch_file(path, safe_load=True)
9
+ activation_func = sd.get('activation_func', 'linear')
10
+ is_layer_norm = sd.get('is_layer_norm', False)
11
+ use_dropout = sd.get('use_dropout', False)
12
+ activate_output = sd.get('activate_output', False)
13
+ last_layer_dropout = sd.get('last_layer_dropout', False)
14
+
15
+ valid_activation = {
16
+ "linear": torch.nn.Identity,
17
+ "relu": torch.nn.ReLU,
18
+ "leakyrelu": torch.nn.LeakyReLU,
19
+ "elu": torch.nn.ELU,
20
+ "swish": torch.nn.Hardswish,
21
+ "tanh": torch.nn.Tanh,
22
+ "sigmoid": torch.nn.Sigmoid,
23
+ "softsign": torch.nn.Softsign,
24
+ "mish": torch.nn.Mish,
25
+ }
26
+
27
+ if activation_func not in valid_activation:
28
+ print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout)
29
+ return None
30
+
31
+ out = {}
32
+
33
+ for d in sd:
34
+ try:
35
+ dim = int(d)
36
+ except:
37
+ continue
38
+
39
+ output = []
40
+ for index in [0, 1]:
41
+ attn_weights = sd[dim][index]
42
+ keys = attn_weights.keys()
43
+
44
+ linears = filter(lambda a: a.endswith(".weight"), keys)
45
+ linears = list(map(lambda a: a[:-len(".weight")], linears))
46
+ layers = []
47
+
48
+ i = 0
49
+ while i < len(linears):
50
+ lin_name = linears[i]
51
+ last_layer = (i == (len(linears) - 1))
52
+ penultimate_layer = (i == (len(linears) - 2))
53
+
54
+ lin_weight = attn_weights['{}.weight'.format(lin_name)]
55
+ lin_bias = attn_weights['{}.bias'.format(lin_name)]
56
+ layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0])
57
+ layer.load_state_dict({"weight": lin_weight, "bias": lin_bias})
58
+ layers.append(layer)
59
+ if activation_func != "linear":
60
+ if (not last_layer) or (activate_output):
61
+ layers.append(valid_activation[activation_func]())
62
+ if is_layer_norm:
63
+ i += 1
64
+ ln_name = linears[i]
65
+ ln_weight = attn_weights['{}.weight'.format(ln_name)]
66
+ ln_bias = attn_weights['{}.bias'.format(ln_name)]
67
+ ln = torch.nn.LayerNorm(ln_weight.shape[0])
68
+ ln.load_state_dict({"weight": ln_weight, "bias": ln_bias})
69
+ layers.append(ln)
70
+ if use_dropout:
71
+ if (not last_layer) and (not penultimate_layer or last_layer_dropout):
72
+ layers.append(torch.nn.Dropout(p=0.3))
73
+ i += 1
74
+
75
+ output.append(torch.nn.Sequential(*layers))
76
+ out[dim] = torch.nn.ModuleList(output)
77
+
78
+ class hypernetwork_patch:
79
+ def __init__(self, hypernet, strength):
80
+ self.hypernet = hypernet
81
+ self.strength = strength
82
+ def __call__(self, q, k, v, extra_options):
83
+ dim = k.shape[-1]
84
+ if dim in self.hypernet:
85
+ hn = self.hypernet[dim]
86
+ k = k + hn[0](k) * self.strength
87
+ v = v + hn[1](v) * self.strength
88
+
89
+ return q, k, v
90
+
91
+ def to(self, device):
92
+ for d in self.hypernet.keys():
93
+ self.hypernet[d] = self.hypernet[d].to(device)
94
+ return self
95
+
96
+ return hypernetwork_patch(out, strength)
97
+
98
+ class HypernetworkLoader:
99
+ @classmethod
100
+ def INPUT_TYPES(s):
101
+ return {"required": { "model": ("MODEL",),
102
+ "hypernetwork_name": (ldm_patched.utils.path_utils.get_filename_list("hypernetworks"), ),
103
+ "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
104
+ }}
105
+ RETURN_TYPES = ("MODEL",)
106
+ FUNCTION = "load_hypernetwork"
107
+
108
+ CATEGORY = "loaders"
109
+
110
+ def load_hypernetwork(self, model, hypernetwork_name, strength):
111
+ hypernetwork_path = ldm_patched.utils.path_utils.get_full_path("hypernetworks", hypernetwork_name)
112
+ model_hypernetwork = model.clone()
113
+ patch = load_hypernetwork_patch(hypernetwork_path, strength)
114
+ if patch is not None:
115
+ model_hypernetwork.set_model_attn1_patch(patch)
116
+ model_hypernetwork.set_model_attn2_patch(patch)
117
+ return (model_hypernetwork,)
118
+
119
+ NODE_CLASS_MAPPINGS = {
120
+ "HypernetworkLoader": HypernetworkLoader
121
+ }
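The object returned by load_hypernetwork_patch is installed as an attn1/attn2 patch, so on every attention call it nudges keys and values: k becomes k + strength * HN_k(k) and v becomes v + strength * HN_v(v), where HN_k and HN_v are the two small networks stored for that attention width. A minimal sketch of that update with stand-in single-layer networks (the real ones come from the loaded state dict):

    import torch

    dim, strength = 64, 0.7
    hypernet = {dim: torch.nn.ModuleList([torch.nn.Linear(dim, dim),    # HN_k stand-in
                                          torch.nn.Linear(dim, dim)])}  # HN_v stand-in

    q = torch.randn(2, 77, dim)
    k = torch.randn(2, 77, dim)
    v = torch.randn(2, 77, dim)

    if k.shape[-1] in hypernet:          # same dispatch as hypernetwork_patch.__call__
        hn = hypernet[k.shape[-1]]
        k = k + hn[0](k) * strength
        v = v + hn[1](v) * strength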
ldm_patched/contrib/external_hypertile.py ADDED
@@ -0,0 +1,85 @@
1
+ # https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
2
+
3
+ #Taken from: https://github.com/tfernd/HyperTile/
4
+
5
+ import math
6
+ from einops import rearrange
7
+ # Use torch rng for consistency across generations
8
+ from torch import randint
9
+
10
+ def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
11
+ min_value = min(min_value, value)
12
+
13
+ # All big divisors of value (inclusive)
14
+ divisors = [i for i in range(min_value, value + 1) if value % i == 0]
15
+
16
+ ns = [value // i for i in divisors[:max_options]] # has at least 1 element
17
+
18
+ if len(ns) - 1 > 0:
19
+ idx = randint(low=0, high=len(ns) - 1, size=(1,)).item()
20
+ else:
21
+ idx = 0
22
+
23
+ return ns[idx]
24
+
25
+ class HyperTile:
26
+ @classmethod
27
+ def INPUT_TYPES(s):
28
+ return {"required": { "model": ("MODEL",),
29
+ "tile_size": ("INT", {"default": 256, "min": 1, "max": 2048}),
30
+ "swap_size": ("INT", {"default": 2, "min": 1, "max": 128}),
31
+ "max_depth": ("INT", {"default": 0, "min": 0, "max": 10}),
32
+ "scale_depth": ("BOOLEAN", {"default": False}),
33
+ }}
34
+ RETURN_TYPES = ("MODEL",)
35
+ FUNCTION = "patch"
36
+
37
+ CATEGORY = "model_patches"
38
+
39
+ def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
40
+ model_channels = model.model.model_config.unet_config["model_channels"]
41
+
42
+ latent_tile_size = max(32, tile_size) // 8
43
+ self.temp = None
44
+
45
+ def hypertile_in(q, k, v, extra_options):
46
+ model_chans = q.shape[-2]
47
+ orig_shape = extra_options['original_shape']
48
+ apply_to = []
49
+ for i in range(max_depth + 1):
50
+ apply_to.append((orig_shape[-2] / (2 ** i)) * (orig_shape[-1] / (2 ** i)))
51
+
52
+ if model_chans in apply_to:
53
+ shape = extra_options["original_shape"]
54
+ aspect_ratio = shape[-1] / shape[-2]
55
+
56
+ hw = q.size(1)
57
+ h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
58
+
59
+ factor = (2 ** apply_to.index(model_chans)) if scale_depth else 1
60
+ nh = random_divisor(h, latent_tile_size * factor, swap_size)
61
+ nw = random_divisor(w, latent_tile_size * factor, swap_size)
62
+
63
+ if nh * nw > 1:
64
+ q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
65
+ self.temp = (nh, nw, h, w)
66
+ return q, k, v
67
+
68
+ return q, k, v
69
+ def hypertile_out(out, extra_options):
70
+ if self.temp is not None:
71
+ nh, nw, h, w = self.temp
72
+ self.temp = None
73
+ out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
74
+ out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
75
+ return out
76
+
77
+
78
+ m = model.clone()
79
+ m.set_model_attn1_patch(hypertile_in)
80
+ m.set_model_attn1_output_patch(hypertile_out)
81
+ return (m, )
82
+
83
+ NODE_CLASS_MAPPINGS = {
84
+ "HyperTile": HyperTile,
85
+ }
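random_divisor is what keeps the tiling exact: it lists the divisors of the latent side length that are at least the minimum tile size and randomly returns one of the corresponding tile counts, so nh and nw always divide h and w. A copy of the helper with a concrete call (96 latent pixels, 32-pixel minimum tile):

    from torch import randint

    def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
        min_value = min(min_value, value)
        divisors = [i for i in range(min_value, value + 1) if value % i == 0]
        ns = [value // i for i in divisors[:max_options]]   # candidate tile counts
        idx = randint(low=0, high=len(ns) - 1, size=(1,)).item() if len(ns) > 1 else 0
        return ns[idx]

    # Divisors of 96 that are >= 32 are 32, 48 and 96, so with max_options=3 the
    # candidate counts are [3, 2, 1]; the exclusive-high randint draw returns 3 or 2 here.
    print([random_divisor(96, 32, max_options=3) for _ in range(5)])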
ldm_patched/contrib/external_images.py ADDED
@@ -0,0 +1,177 @@
1
+ # https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
2
+
3
+ import ldm_patched.contrib.external
4
+ import ldm_patched.utils.path_utils
5
+ from ldm_patched.modules.args_parser import args
6
+
7
+ from PIL import Image
8
+ from PIL.PngImagePlugin import PngInfo
9
+
10
+ import numpy as np
11
+ import json
12
+ import os
13
+
14
+ MAX_RESOLUTION = ldm_patched.contrib.external.MAX_RESOLUTION
15
+
16
+ class ImageCrop:
17
+ @classmethod
18
+ def INPUT_TYPES(s):
19
+ return {"required": { "image": ("IMAGE",),
20
+ "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
21
+ "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
22
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
23
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
24
+ }}
25
+ RETURN_TYPES = ("IMAGE",)
26
+ FUNCTION = "crop"
27
+
28
+ CATEGORY = "image/transform"
29
+
30
+ def crop(self, image, width, height, x, y):
31
+ x = min(x, image.shape[2] - 1)
32
+ y = min(y, image.shape[1] - 1)
33
+ to_x = width + x
34
+ to_y = height + y
35
+ img = image[:,y:to_y, x:to_x, :]
36
+ return (img,)
37
+
38
+ class RepeatImageBatch:
39
+ @classmethod
40
+ def INPUT_TYPES(s):
41
+ return {"required": { "image": ("IMAGE",),
42
+ "amount": ("INT", {"default": 1, "min": 1, "max": 64}),
43
+ }}
44
+ RETURN_TYPES = ("IMAGE",)
45
+ FUNCTION = "repeat"
46
+
47
+ CATEGORY = "image/batch"
48
+
49
+ def repeat(self, image, amount):
50
+ s = image.repeat((amount, 1,1,1))
51
+ return (s,)
52
+
53
+ class SaveAnimatedWEBP:
54
+ def __init__(self):
55
+ self.output_dir = ldm_patched.utils.path_utils.get_output_directory()
56
+ self.type = "output"
57
+ self.prefix_append = ""
58
+
59
+ methods = {"default": 4, "fastest": 0, "slowest": 6}
60
+ @classmethod
61
+ def INPUT_TYPES(s):
62
+ return {"required":
63
+ {"images": ("IMAGE", ),
64
+ "filename_prefix": ("STRING", {"default": "ldm_patched"}),
65
+ "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
66
+ "lossless": ("BOOLEAN", {"default": True}),
67
+ "quality": ("INT", {"default": 80, "min": 0, "max": 100}),
68
+ "method": (list(s.methods.keys()),),
69
+ # "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}),
70
+ },
71
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
72
+ }
73
+
74
+ RETURN_TYPES = ()
75
+ FUNCTION = "save_images"
76
+
77
+ OUTPUT_NODE = True
78
+
79
+ CATEGORY = "image/animation"
80
+
81
+ def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
82
+ method = self.methods.get(method)
83
+ filename_prefix += self.prefix_append
84
+ full_output_folder, filename, counter, subfolder, filename_prefix = ldm_patched.utils.path_utils.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
85
+ results = list()
86
+ pil_images = []
87
+ for image in images:
88
+ i = 255. * image.cpu().numpy()
89
+ img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
90
+ pil_images.append(img)
91
+
92
+ metadata = pil_images[0].getexif()
93
+ if not args.disable_server_info:
94
+ if prompt is not None:
95
+ metadata[0x0110] = "prompt:{}".format(json.dumps(prompt))
96
+ if extra_pnginfo is not None:
97
+ inital_exif = 0x010f
98
+ for x in extra_pnginfo:
99
+ metadata[inital_exif] = "{}:{}".format(x, json.dumps(extra_pnginfo[x]))
100
+ inital_exif -= 1
101
+
102
+ if num_frames == 0:
103
+ num_frames = len(pil_images)
104
+
105
+ c = len(pil_images)
106
+ for i in range(0, c, num_frames):
107
+ file = f"{filename}_{counter:05}_.webp"
108
+ pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method)
109
+ results.append({
110
+ "filename": file,
111
+ "subfolder": subfolder,
112
+ "type": self.type
113
+ })
114
+ counter += 1
115
+
116
+ animated = num_frames != 1
117
+ return { "ui": { "images": results, "animated": (animated,) } }
118
+
119
+ class SaveAnimatedPNG:
120
+ def __init__(self):
121
+ self.output_dir = ldm_patched.utils.path_utils.get_output_directory()
122
+ self.type = "output"
123
+ self.prefix_append = ""
124
+
125
+ @classmethod
126
+ def INPUT_TYPES(s):
127
+ return {"required":
128
+ {"images": ("IMAGE", ),
129
+ "filename_prefix": ("STRING", {"default": "ldm_patched"}),
130
+ "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
131
+ "compress_level": ("INT", {"default": 4, "min": 0, "max": 9})
132
+ },
133
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
134
+ }
135
+
136
+ RETURN_TYPES = ()
137
+ FUNCTION = "save_images"
138
+
139
+ OUTPUT_NODE = True
140
+
141
+ CATEGORY = "image/animation"
142
+
143
+ def save_images(self, images, fps, compress_level, filename_prefix="ldm_patched", prompt=None, extra_pnginfo=None):
144
+ filename_prefix += self.prefix_append
145
+ full_output_folder, filename, counter, subfolder, filename_prefix = ldm_patched.utils.path_utils.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
146
+ results = list()
147
+ pil_images = []
148
+ for image in images:
149
+ i = 255. * image.cpu().numpy()
150
+ img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
151
+ pil_images.append(img)
152
+
153
+ metadata = None
154
+ if not args.disable_server_info:
155
+ metadata = PngInfo()
156
+ if prompt is not None:
157
+ metadata.add(b"ldm_patched", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True)
158
+ if extra_pnginfo is not None:
159
+ for x in extra_pnginfo:
160
+ metadata.add(b"ldm_patched", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True)
161
+
162
+ file = f"{filename}_{counter:05}_.png"
163
+ pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:])
164
+ results.append({
165
+ "filename": file,
166
+ "subfolder": subfolder,
167
+ "type": self.type
168
+ })
169
+
170
+ return { "ui": { "images": results, "animated": (True,)} }
171
+
172
+ NODE_CLASS_MAPPINGS = {
173
+ "ImageCrop": ImageCrop,
174
+ "RepeatImageBatch": RepeatImageBatch,
175
+ "SaveAnimatedWEBP": SaveAnimatedWEBP,
176
+ "SaveAnimatedPNG": SaveAnimatedPNG,
177
+ }
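SaveAnimatedWEBP boils down to converting each float [0, 1] HWC tensor to a uint8 PIL image and writing the batch with Pillow's animated-WebP options, with a per-frame duration of int(1000 / fps) milliseconds. A stripped-down version of that path without the metadata and output-folder bookkeeping (random frames, fixed filename):

    import numpy as np
    import torch
    from PIL import Image

    images = torch.rand(8, 64, 64, 3)   # stand-in batch of frames
    fps = 6.0

    pil_images = []
    for image in images:
        i = 255. * image.cpu().numpy()
        pil_images.append(Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)))

    pil_images[0].save("frames_.webp", save_all=True, duration=int(1000.0 / fps),
                       append_images=pil_images[1:], lossless=True, quality=80, method=4)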
ldm_patched/contrib/external_latent.py ADDED
@@ -0,0 +1,157 @@
1
+ # https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
2
+
3
+ import ldm_patched.modules.utils
4
+ import torch
5
+
6
+ def reshape_latent_to(target_shape, latent):
7
+ if latent.shape[1:] != target_shape[1:]:
8
+ latent = ldm_patched.modules.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center")
9
+ return ldm_patched.modules.utils.repeat_to_batch_size(latent, target_shape[0])
10
+
11
+
12
+ class LatentAdd:
13
+ @classmethod
14
+ def INPUT_TYPES(s):
15
+ return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
16
+
17
+ RETURN_TYPES = ("LATENT",)
18
+ FUNCTION = "op"
19
+
20
+ CATEGORY = "latent/advanced"
21
+
22
+ def op(self, samples1, samples2):
23
+ samples_out = samples1.copy()
24
+
25
+ s1 = samples1["samples"]
26
+ s2 = samples2["samples"]
27
+
28
+ s2 = reshape_latent_to(s1.shape, s2)
29
+ samples_out["samples"] = s1 + s2
30
+ return (samples_out,)
31
+
32
+ class LatentSubtract:
33
+ @classmethod
34
+ def INPUT_TYPES(s):
35
+ return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
36
+
37
+ RETURN_TYPES = ("LATENT",)
38
+ FUNCTION = "op"
39
+
40
+ CATEGORY = "latent/advanced"
41
+
42
+ def op(self, samples1, samples2):
43
+ samples_out = samples1.copy()
44
+
45
+ s1 = samples1["samples"]
46
+ s2 = samples2["samples"]
47
+
48
+ s2 = reshape_latent_to(s1.shape, s2)
49
+ samples_out["samples"] = s1 - s2
50
+ return (samples_out,)
51
+
52
+ class LatentMultiply:
53
+ @classmethod
54
+ def INPUT_TYPES(s):
55
+ return {"required": { "samples": ("LATENT",),
56
+ "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
57
+ }}
58
+
59
+ RETURN_TYPES = ("LATENT",)
60
+ FUNCTION = "op"
61
+
62
+ CATEGORY = "latent/advanced"
63
+
64
+ def op(self, samples, multiplier):
65
+ samples_out = samples.copy()
66
+
67
+ s1 = samples["samples"]
68
+ samples_out["samples"] = s1 * multiplier
69
+ return (samples_out,)
70
+
71
+ class LatentInterpolate:
72
+ @classmethod
73
+ def INPUT_TYPES(s):
74
+ return {"required": { "samples1": ("LATENT",),
75
+ "samples2": ("LATENT",),
76
+ "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
77
+ }}
78
+
79
+ RETURN_TYPES = ("LATENT",)
80
+ FUNCTION = "op"
81
+
82
+ CATEGORY = "latent/advanced"
83
+
84
+ def op(self, samples1, samples2, ratio):
85
+ samples_out = samples1.copy()
86
+
87
+ s1 = samples1["samples"]
88
+ s2 = samples2["samples"]
89
+
90
+ s2 = reshape_latent_to(s1.shape, s2)
91
+
92
+ m1 = torch.linalg.vector_norm(s1, dim=(1))
93
+ m2 = torch.linalg.vector_norm(s2, dim=(1))
94
+
95
+ s1 = torch.nan_to_num(s1 / m1)
96
+ s2 = torch.nan_to_num(s2 / m2)
97
+
98
+ t = (s1 * ratio + s2 * (1.0 - ratio))
99
+ mt = torch.linalg.vector_norm(t, dim=(1))
100
+ st = torch.nan_to_num(t / mt)
101
+
102
+ samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
103
+ return (samples_out,)
104
+
105
+ class LatentBatch:
106
+ @classmethod
107
+ def INPUT_TYPES(s):
108
+ return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
109
+
110
+ RETURN_TYPES = ("LATENT",)
111
+ FUNCTION = "batch"
112
+
113
+ CATEGORY = "latent/batch"
114
+
115
+ def batch(self, samples1, samples2):
116
+ samples_out = samples1.copy()
117
+ s1 = samples1["samples"]
118
+ s2 = samples2["samples"]
119
+
120
+ if s1.shape[1:] != s2.shape[1:]:
121
+ s2 = ldm_patched.modules.utils.common_upscale(s2, s1.shape[3], s1.shape[2], "bilinear", "center")
122
+ s = torch.cat((s1, s2), dim=0)
123
+ samples_out["samples"] = s
124
+ samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
125
+ return (samples_out,)
126
+
127
+ class LatentBatchSeedBehavior:
128
+ @classmethod
129
+ def INPUT_TYPES(s):
130
+ return {"required": { "samples": ("LATENT",),
131
+ "seed_behavior": (["random", "fixed"],),}}
132
+
133
+ RETURN_TYPES = ("LATENT",)
134
+ FUNCTION = "op"
135
+
136
+ CATEGORY = "latent/advanced"
137
+
138
+ def op(self, samples, seed_behavior):
139
+ samples_out = samples.copy()
140
+ latent = samples["samples"]
141
+ if seed_behavior == "random":
142
+ if 'batch_index' in samples_out:
143
+ samples_out.pop('batch_index')
144
+ elif seed_behavior == "fixed":
145
+ batch_number = samples_out.get("batch_index", [0])[0]
146
+ samples_out["batch_index"] = [batch_number] * latent.shape[0]
147
+
148
+ return (samples_out,)
149
+
150
+ NODE_CLASS_MAPPINGS = {
151
+ "LatentAdd": LatentAdd,
152
+ "LatentSubtract": LatentSubtract,
153
+ "LatentMultiply": LatentMultiply,
154
+ "LatentInterpolate": LatentInterpolate,
155
+ "LatentBatch": LatentBatch,
156
+ "LatentBatchSeedBehavior": LatentBatchSeedBehavior,
157
+ }
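LatentInterpolate is not a plain lerp: it normalises each latent per pixel across the channel dimension, lerps those unit directions, renormalises, and finally rescales by the lerp of the original channel norms, so the interpolated latent keeps a reasonable magnitude. The same math on a single random latent pair, without the LATENT-dict plumbing:

    import torch

    s1, s2, ratio = torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8), 0.3

    m1 = torch.linalg.vector_norm(s1, dim=1)   # per-pixel channel norms
    m2 = torch.linalg.vector_norm(s2, dim=1)

    d1 = torch.nan_to_num(s1 / m1)             # unit directions
    d2 = torch.nan_to_num(s2 / m2)

    t = d1 * ratio + d2 * (1.0 - ratio)        # lerp the directions
    st = torch.nan_to_num(t / torch.linalg.vector_norm(t, dim=1))

    out = st * (m1 * ratio + m2 * (1.0 - ratio))   # restore a lerped magnitude
    print(out.shape)  # torch.Size([1, 4, 8, 8])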
ldm_patched/contrib/external_mask.py ADDED
@@ -0,0 +1,365 @@
1
+ # https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
2
+
3
+ import numpy as np
4
+ import scipy.ndimage
5
+ import torch
6
+ import ldm_patched.modules.utils
7
+
8
+ from ldm_patched.contrib.external import MAX_RESOLUTION
9
+
10
+ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
11
+ source = source.to(destination.device)
12
+ if resize_source:
13
+ source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
14
+
15
+ source = ldm_patched.modules.utils.repeat_to_batch_size(source, destination.shape[0])
16
+
17
+ x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier))
18
+ y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier))
19
+
20
+ left, top = (x // multiplier, y // multiplier)
21
+ right, bottom = (left + source.shape[3], top + source.shape[2],)
22
+
23
+ if mask is None:
24
+ mask = torch.ones_like(source)
25
+ else:
26
+ mask = mask.to(destination.device, copy=True)
27
+ mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear")
28
+ mask = ldm_patched.modules.utils.repeat_to_batch_size(mask, source.shape[0])
29
+
30
+ # calculate the bounds of the source that will be overlapping the destination
31
+ # this prevents the source trying to overwrite latent pixels that are out of bounds
32
+ # of the destination
33
+ visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),)
34
+
35
+ mask = mask[:, :, :visible_height, :visible_width]
36
+ inverse_mask = torch.ones_like(mask) - mask
37
+
38
+ source_portion = mask * source[:, :, :visible_height, :visible_width]
39
+ destination_portion = inverse_mask * destination[:, :, top:bottom, left:right]
40
+
41
+ destination[:, :, top:bottom, left:right] = source_portion + destination_portion
42
+ return destination
43
+
44
+ class LatentCompositeMasked:
45
+ @classmethod
46
+ def INPUT_TYPES(s):
47
+ return {
48
+ "required": {
49
+ "destination": ("LATENT",),
50
+ "source": ("LATENT",),
51
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
52
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
53
+ "resize_source": ("BOOLEAN", {"default": False}),
54
+ },
55
+ "optional": {
56
+ "mask": ("MASK",),
57
+ }
58
+ }
59
+ RETURN_TYPES = ("LATENT",)
60
+ FUNCTION = "composite"
61
+
62
+ CATEGORY = "latent"
63
+
64
+ def composite(self, destination, source, x, y, resize_source, mask = None):
65
+ output = destination.copy()
66
+ destination = destination["samples"].clone()
67
+ source = source["samples"]
68
+ output["samples"] = composite(destination, source, x, y, mask, 8, resize_source)
69
+ return (output,)
70
+
71
+ class ImageCompositeMasked:
72
+ @classmethod
73
+ def INPUT_TYPES(s):
74
+ return {
75
+ "required": {
76
+ "destination": ("IMAGE",),
77
+ "source": ("IMAGE",),
78
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
79
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
80
+ "resize_source": ("BOOLEAN", {"default": False}),
81
+ },
82
+ "optional": {
83
+ "mask": ("MASK",),
84
+ }
85
+ }
86
+ RETURN_TYPES = ("IMAGE",)
87
+ FUNCTION = "composite"
88
+
89
+ CATEGORY = "image"
90
+
91
+ def composite(self, destination, source, x, y, resize_source, mask = None):
92
+ destination = destination.clone().movedim(-1, 1)
93
+ output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
94
+ return (output,)
95
+
96
+ class MaskToImage:
97
+ @classmethod
98
+ def INPUT_TYPES(s):
99
+ return {
100
+ "required": {
101
+ "mask": ("MASK",),
102
+ }
103
+ }
104
+
105
+ CATEGORY = "mask"
106
+
107
+ RETURN_TYPES = ("IMAGE",)
108
+ FUNCTION = "mask_to_image"
109
+
110
+ def mask_to_image(self, mask):
111
+ result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
112
+ return (result,)
113
+
114
+ class ImageToMask:
115
+ @classmethod
116
+ def INPUT_TYPES(s):
117
+ return {
118
+ "required": {
119
+ "image": ("IMAGE",),
120
+ "channel": (["red", "green", "blue", "alpha"],),
121
+ }
122
+ }
123
+
124
+ CATEGORY = "mask"
125
+
126
+ RETURN_TYPES = ("MASK",)
127
+ FUNCTION = "image_to_mask"
128
+
129
+ def image_to_mask(self, image, channel):
130
+ channels = ["red", "green", "blue", "alpha"]
131
+ mask = image[:, :, :, channels.index(channel)]
132
+ return (mask,)
133
+
134
+ class ImageColorToMask:
135
+ @classmethod
136
+ def INPUT_TYPES(s):
137
+ return {
138
+ "required": {
139
+ "image": ("IMAGE",),
140
+ "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
141
+ }
142
+ }
143
+
144
+ CATEGORY = "mask"
145
+
146
+ RETURN_TYPES = ("MASK",)
147
+ FUNCTION = "image_to_mask"
148
+
149
+ def image_to_mask(self, image, color):
150
+ temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int)
151
+ temp = torch.bitwise_left_shift(temp[:,:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,:,1], 8) + temp[:,:,:,2]
152
+ mask = torch.where(temp == color, 255, 0).float()
153
+ return (mask,)
154
+
155
+ class SolidMask:
156
+ @classmethod
157
+ def INPUT_TYPES(cls):
158
+ return {
159
+ "required": {
160
+ "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
161
+ "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
162
+ "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
163
+ }
164
+ }
165
+
166
+ CATEGORY = "mask"
167
+
168
+ RETURN_TYPES = ("MASK",)
169
+
170
+ FUNCTION = "solid"
171
+
172
+ def solid(self, value, width, height):
173
+ out = torch.full((1, height, width), value, dtype=torch.float32, device="cpu")
174
+ return (out,)
175
+
176
+ class InvertMask:
177
+ @classmethod
178
+ def INPUT_TYPES(cls):
179
+ return {
180
+ "required": {
181
+ "mask": ("MASK",),
182
+ }
183
+ }
184
+
185
+ CATEGORY = "mask"
186
+
187
+ RETURN_TYPES = ("MASK",)
188
+
189
+ FUNCTION = "invert"
190
+
191
+ def invert(self, mask):
192
+ out = 1.0 - mask
193
+ return (out,)
194
+
195
+ class CropMask:
196
+ @classmethod
197
+ def INPUT_TYPES(cls):
198
+ return {
199
+ "required": {
200
+ "mask": ("MASK",),
201
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
202
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
203
+ "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
204
+ "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
205
+ }
206
+ }
207
+
208
+ CATEGORY = "mask"
209
+
210
+ RETURN_TYPES = ("MASK",)
211
+
212
+ FUNCTION = "crop"
213
+
214
+ def crop(self, mask, x, y, width, height):
215
+ mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
216
+ out = mask[:, y:y + height, x:x + width]
217
+ return (out,)
218
+
219
+ class MaskComposite:
220
+ @classmethod
221
+ def INPUT_TYPES(cls):
222
+ return {
223
+ "required": {
224
+ "destination": ("MASK",),
225
+ "source": ("MASK",),
226
+ "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
227
+ "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
228
+ "operation": (["multiply", "add", "subtract", "and", "or", "xor"],),
229
+ }
230
+ }
231
+
232
+ CATEGORY = "mask"
233
+
234
+ RETURN_TYPES = ("MASK",)
235
+
236
+ FUNCTION = "combine"
237
+
238
+ def combine(self, destination, source, x, y, operation):
239
+ output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone()
240
+ source = source.reshape((-1, source.shape[-2], source.shape[-1]))
241
+
242
+ left, top = (x, y,)
243
+ right, bottom = (min(left + source.shape[-1], destination.shape[-1]), min(top + source.shape[-2], destination.shape[-2]))
244
+ visible_width, visible_height = (right - left, bottom - top,)
245
+
246
+ source_portion = source[:, :visible_height, :visible_width]
247
+ destination_portion = destination[:, top:bottom, left:right]
248
+
249
+ if operation == "multiply":
250
+ output[:, top:bottom, left:right] = destination_portion * source_portion
251
+ elif operation == "add":
252
+ output[:, top:bottom, left:right] = destination_portion + source_portion
253
+ elif operation == "subtract":
254
+ output[:, top:bottom, left:right] = destination_portion - source_portion
255
+ elif operation == "and":
256
+ output[:, top:bottom, left:right] = torch.bitwise_and(destination_portion.round().bool(), source_portion.round().bool()).float()
257
+ elif operation == "or":
258
+ output[:, top:bottom, left:right] = torch.bitwise_or(destination_portion.round().bool(), source_portion.round().bool()).float()
259
+ elif operation == "xor":
260
+ output[:, top:bottom, left:right] = torch.bitwise_xor(destination_portion.round().bool(), source_portion.round().bool()).float()
261
+
262
+ output = torch.clamp(output, 0.0, 1.0)
263
+
264
+ return (output,)
265
+
266
+ class FeatherMask:
267
+ @classmethod
268
+ def INPUT_TYPES(cls):
269
+ return {
270
+ "required": {
271
+ "mask": ("MASK",),
272
+ "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
273
+ "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
274
+ "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
275
+ "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
276
+ }
277
+ }
278
+
279
+ CATEGORY = "mask"
280
+
281
+ RETURN_TYPES = ("MASK",)
282
+
283
+ FUNCTION = "feather"
284
+
285
+ def feather(self, mask, left, top, right, bottom):
286
+ output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()
287
+
288
+ left = min(left, output.shape[-1])
289
+ right = min(right, output.shape[-1])
290
+ top = min(top, output.shape[-2])
291
+ bottom = min(bottom, output.shape[-2])
292
+
293
+ for x in range(left):
294
+ feather_rate = (x + 1.0) / left
295
+ output[:, :, x] *= feather_rate
296
+
297
+ for x in range(right):
298
+ feather_rate = (x + 1) / right
299
+ output[:, :, -x - 1] *= feather_rate  # count in from the right edge; plain -x would hit column 0 when x == 0
300
+
301
+ for y in range(top):
302
+ feather_rate = (y + 1) / top
303
+ output[:, y, :] *= feather_rate
304
+
305
+ for y in range(bottom):
306
+ feather_rate = (y + 1) / bottom
307
+ output[:, -y - 1, :] *= feather_rate  # count in from the bottom edge; plain -y would hit row 0 when y == 0
308
+
309
+ return (output,)
310
+
311
+ class GrowMask:
312
+ @classmethod
313
+ def INPUT_TYPES(cls):
314
+ return {
315
+ "required": {
316
+ "mask": ("MASK",),
317
+ "expand": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 1}),
318
+ "tapered_corners": ("BOOLEAN", {"default": True}),
319
+ },
320
+ }
321
+
322
+ CATEGORY = "mask"
323
+
324
+ RETURN_TYPES = ("MASK",)
325
+
326
+ FUNCTION = "expand_mask"
327
+
328
+ def expand_mask(self, mask, expand, tapered_corners):
329
+ c = 0 if tapered_corners else 1
330
+ kernel = np.array([[c, 1, c],
331
+ [1, 1, 1],
332
+ [c, 1, c]])
333
+ mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
334
+ out = []
335
+ for m in mask:
336
+ output = m.numpy()
337
+ for _ in range(abs(expand)):
338
+ if expand < 0:
339
+ output = scipy.ndimage.grey_erosion(output, footprint=kernel)
340
+ else:
341
+ output = scipy.ndimage.grey_dilation(output, footprint=kernel)
342
+ output = torch.from_numpy(output)
343
+ out.append(output)
344
+ return (torch.stack(out, dim=0),)
345
+
346
+
347
+
348
+ NODE_CLASS_MAPPINGS = {
349
+ "LatentCompositeMasked": LatentCompositeMasked,
350
+ "ImageCompositeMasked": ImageCompositeMasked,
351
+ "MaskToImage": MaskToImage,
352
+ "ImageToMask": ImageToMask,
353
+ "ImageColorToMask": ImageColorToMask,
354
+ "SolidMask": SolidMask,
355
+ "InvertMask": InvertMask,
356
+ "CropMask": CropMask,
357
+ "MaskComposite": MaskComposite,
358
+ "FeatherMask": FeatherMask,
359
+ "GrowMask": GrowMask,
360
+ }
361
+
362
+ NODE_DISPLAY_NAME_MAPPINGS = {
363
+ "ImageToMask": "Convert Image to Mask",
364
+ "MaskToImage": "Convert Mask to Image",
365
+ }
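
The mask nodes above are thin wrappers around plain tensor math. As a point of reference, here is a minimal standalone sketch of the edge ramp FeatherMask applies, assuming a float mask of shape (B, H, W) with values in [0, 1]; the helper name is made up for illustration and is not part of the node code:

import torch

def feather_left_edge(mask: torch.Tensor, left: int) -> torch.Tensor:
    # Linearly attenuate the leftmost `left` columns, strongest at the outer edge,
    # mirroring the left-edge loop in FeatherMask.
    out = mask.clone()
    left = min(left, out.shape[-1])
    for x in range(left):
        out[:, :, x] *= (x + 1.0) / left
    return out

m = torch.ones(1, 4, 8)
print(feather_left_edge(m, 4)[0, 0])
# tensor([0.2500, 0.5000, 0.7500, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
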
ldm_patched/contrib/external_model_advanced.py ADDED
@@ -0,0 +1,188 @@
1
+ # https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
2
+
3
+ import ldm_patched.utils.path_utils
4
+ import ldm_patched.modules.sd
5
+ import ldm_patched.modules.model_sampling
6
+ import torch
7
+
8
+ class LCM(ldm_patched.modules.model_sampling.EPS):
9
+ def calculate_denoised(self, sigma, model_output, model_input):
10
+ timestep = self.timestep(sigma).view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
11
+ sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
12
+ x0 = model_input - model_output * sigma
13
+
14
+ sigma_data = 0.5
15
+ scaled_timestep = timestep * 10.0 #timestep_scaling
16
+
17
+ c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
18
+ c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
19
+
20
+ return c_out * x0 + c_skip * model_input
21
+
22
+ class ModelSamplingDiscreteDistilled(ldm_patched.modules.model_sampling.ModelSamplingDiscrete):
23
+ original_timesteps = 50
24
+
25
+ def __init__(self, model_config=None):
26
+ super().__init__(model_config)
27
+
28
+ self.skip_steps = self.num_timesteps // self.original_timesteps
29
+
30
+ sigmas_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
31
+ for x in range(self.original_timesteps):
32
+ sigmas_valid[self.original_timesteps - 1 - x] = self.sigmas[self.num_timesteps - 1 - x * self.skip_steps]
33
+
34
+ self.set_sigmas(sigmas_valid)
35
+
36
+ def timestep(self, sigma):
37
+ log_sigma = sigma.log()
38
+ dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
39
+ return (dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)).to(sigma.device)
40
+
41
+ def sigma(self, timestep):
42
+ t = torch.clamp(((timestep.float().to(self.log_sigmas.device) - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
43
+ low_idx = t.floor().long()
44
+ high_idx = t.ceil().long()
45
+ w = t.frac()
46
+ log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
47
+ return log_sigma.exp().to(timestep.device)
48
+
49
+
50
+ def rescale_zero_terminal_snr_sigmas(sigmas):
51
+ alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
52
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
53
+
54
+ # Store old values.
55
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
56
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
57
+
58
+ # Shift so the last timestep is zero.
59
+ alphas_bar_sqrt -= (alphas_bar_sqrt_T)
60
+
61
+ # Scale so the first timestep is back to the old value.
62
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
63
+
64
+ # Convert alphas_bar_sqrt to betas
65
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
66
+ alphas_bar[-1] = 4.8973451890853435e-08
67
+ return ((1 - alphas_bar) / alphas_bar) ** 0.5
68
+
69
+ class ModelSamplingDiscrete:
70
+ @classmethod
71
+ def INPUT_TYPES(s):
72
+ return {"required": { "model": ("MODEL",),
73
+ "sampling": (["eps", "v_prediction", "lcm", "tcd"]),
74
+ "zsnr": ("BOOLEAN", {"default": False}),
75
+ }}
76
+
77
+ RETURN_TYPES = ("MODEL",)
78
+ FUNCTION = "patch"
79
+
80
+ CATEGORY = "advanced/model"
81
+
82
+ def patch(self, model, sampling, zsnr):
83
+ m = model.clone()
84
+
85
+ sampling_base = ldm_patched.modules.model_sampling.ModelSamplingDiscrete
86
+ if sampling == "eps":
87
+ sampling_type = ldm_patched.modules.model_sampling.EPS
88
+ elif sampling == "v_prediction":
89
+ sampling_type = ldm_patched.modules.model_sampling.V_PREDICTION
90
+ elif sampling == "lcm":
91
+ sampling_type = LCM
92
+ sampling_base = ModelSamplingDiscreteDistilled
93
+ elif sampling == "tcd":
94
+ sampling_type = ldm_patched.modules.model_sampling.EPS
95
+ sampling_base = ModelSamplingDiscreteDistilled
96
+
97
+ class ModelSamplingAdvanced(sampling_base, sampling_type):
98
+ pass
99
+
100
+ model_sampling = ModelSamplingAdvanced(model.model.model_config)
101
+ if zsnr:
102
+ model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
103
+
104
+ m.add_object_patch("model_sampling", model_sampling)
105
+ return (m, )
106
+
107
+ class ModelSamplingContinuousEDM:
108
+ @classmethod
109
+ def INPUT_TYPES(s):
110
+ return {"required": { "model": ("MODEL",),
111
+ "sampling": (["v_prediction", "edm_playground_v2.5", "eps"],),
112
+ "sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
113
+ "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}),
114
+ }}
115
+
116
+ RETURN_TYPES = ("MODEL",)
117
+ FUNCTION = "patch"
118
+
119
+ CATEGORY = "advanced/model"
120
+
121
+ def patch(self, model, sampling, sigma_max, sigma_min):
122
+ m = model.clone()
123
+
124
+ latent_format = None
125
+ sigma_data = 1.0
126
+ if sampling == "eps":
127
+ sampling_type = ldm_patched.modules.model_sampling.EPS
128
+ elif sampling == "v_prediction":
129
+ sampling_type = ldm_patched.modules.model_sampling.V_PREDICTION
130
+ elif sampling == "edm_playground_v2.5":
131
+ sampling_type = ldm_patched.modules.model_sampling.EDM
132
+ sigma_data = 0.5
133
+ latent_format = ldm_patched.modules.latent_formats.SDXL_Playground_2_5()
134
+
135
+ class ModelSamplingAdvanced(ldm_patched.modules.model_sampling.ModelSamplingContinuousEDM, sampling_type):
136
+ pass
137
+
138
+ model_sampling = ModelSamplingAdvanced(model.model.model_config)
139
+ model_sampling.set_parameters(sigma_min, sigma_max, sigma_data)
140
+ m.add_object_patch("model_sampling", model_sampling)
141
+ if latent_format is not None:
142
+ m.add_object_patch("latent_format", latent_format)
143
+ return (m, )
144
+
145
+ class RescaleCFG:
146
+ @classmethod
147
+ def INPUT_TYPES(s):
148
+ return {"required": { "model": ("MODEL",),
149
+ "multiplier": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}),
150
+ }}
151
+ RETURN_TYPES = ("MODEL",)
152
+ FUNCTION = "patch"
153
+
154
+ CATEGORY = "advanced/model"
155
+
156
+ def patch(self, model, multiplier):
157
+ def rescale_cfg(args):
158
+ cond = args["cond"]
159
+ uncond = args["uncond"]
160
+ cond_scale = args["cond_scale"]
161
+ sigma = args["sigma"]
162
+ sigma = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1))
163
+ x_orig = args["input"]
164
+
165
+ #rescale cfg has to be done on v-pred model output
166
+ x = x_orig / (sigma * sigma + 1.0)
167
+ cond = ((x - (x_orig - cond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)
168
+ uncond = ((x - (x_orig - uncond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)
169
+
170
+ #rescalecfg
171
+ x_cfg = uncond + cond_scale * (cond - uncond)
172
+ ro_pos = torch.std(cond, dim=(1,2,3), keepdim=True)
173
+ ro_cfg = torch.std(x_cfg, dim=(1,2,3), keepdim=True)
174
+
175
+ x_rescaled = x_cfg * (ro_pos / ro_cfg)
176
+ x_final = multiplier * x_rescaled + (1.0 - multiplier) * x_cfg
177
+
178
+ return x_orig - (x - x_final * sigma / (sigma * sigma + 1.0) ** 0.5)
179
+
180
+ m = model.clone()
181
+ m.set_model_sampler_cfg_function(rescale_cfg)
182
+ return (m, )
183
+
184
+ NODE_CLASS_MAPPINGS = {
185
+ "ModelSamplingDiscrete": ModelSamplingDiscrete,
186
+ "ModelSamplingContinuousEDM": ModelSamplingContinuousEDM,
187
+ "RescaleCFG": RescaleCFG,
188
+ }
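
For reference, the zero-terminal-SNR rescale that ModelSamplingDiscrete applies when zsnr is enabled can be checked in isolation. The function below is copied from the diff; the linspace schedule is only a stand-in for a real model's sigmas:

import torch

def rescale_zero_terminal_snr_sigmas(sigmas):
    # Same math as above: shift alpha_bar so the final step has (almost) zero SNR,
    # then rescale so the first step keeps its original value.
    alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
    alphas_bar_sqrt = alphas_cumprod.sqrt()
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
    alphas_bar_sqrt -= alphas_bar_sqrt_T
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
    alphas_bar = alphas_bar_sqrt ** 2
    alphas_bar[-1] = 4.8973451890853435e-08
    return ((1 - alphas_bar) / alphas_bar) ** 0.5

sigmas = torch.linspace(0.03, 14.6, 1000)        # toy ascending schedule, not a real model's
new_sigmas = rescale_zero_terminal_snr_sigmas(sigmas)
print(float(new_sigmas[0]))     # ~0.03: the low end is preserved
print(float(new_sigmas[-1]))    # ~4500: the final step now has essentially zero SNR
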
ldm_patched/contrib/external_model_downscale.py ADDED
@@ -0,0 +1,55 @@
1
+ # https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
2
+
3
+ import torch
4
+ import ldm_patched.modules.utils
5
+
6
+ class PatchModelAddDownscale:
7
+ upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
8
+ @classmethod
9
+ def INPUT_TYPES(s):
10
+ return {"required": { "model": ("MODEL",),
11
+ "block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}),
12
+ "downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}),
13
+ "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
14
+ "end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
15
+ "downscale_after_skip": ("BOOLEAN", {"default": True}),
16
+ "downscale_method": (s.upscale_methods,),
17
+ "upscale_method": (s.upscale_methods,),
18
+ }}
19
+ RETURN_TYPES = ("MODEL",)
20
+ FUNCTION = "patch"
21
+
22
+ CATEGORY = "_for_testing"
23
+
24
+ def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
25
+ sigma_start = model.model.model_sampling.percent_to_sigma(start_percent)
26
+ sigma_end = model.model.model_sampling.percent_to_sigma(end_percent)
27
+
28
+ def input_block_patch(h, transformer_options):
29
+ if transformer_options["block"][1] == block_number:
30
+ sigma = transformer_options["sigmas"][0].item()
31
+ if sigma <= sigma_start and sigma >= sigma_end:
32
+ h = ldm_patched.modules.utils.common_upscale(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled")
33
+ return h
34
+
35
+ def output_block_patch(h, hsp, transformer_options):
36
+ if h.shape[2] != hsp.shape[2]:
37
+ h = ldm_patched.modules.utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled")
38
+ return h, hsp
39
+
40
+ m = model.clone()
41
+ if downscale_after_skip:
42
+ m.set_model_input_block_patch_after_skip(input_block_patch)
43
+ else:
44
+ m.set_model_input_block_patch(input_block_patch)
45
+ m.set_model_output_block_patch(output_block_patch)
46
+ return (m, )
47
+
48
+ NODE_CLASS_MAPPINGS = {
49
+ "PatchModelAddDownscale": PatchModelAddDownscale,
50
+ }
51
+
52
+ NODE_DISPLAY_NAME_MAPPINGS = {
53
+ # Sampling
54
+ "PatchModelAddDownscale": "PatchModelAddDownscale (Kohya Deep Shrink)",
55
+ }
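
A rough sketch of the shape arithmetic behind the Deep Shrink patch above: the hidden states of one input block are downscaled inside the sigma window, and the matching skip connection is upscaled back. torch.nn.functional.interpolate stands in here for ldm_patched.modules.utils.common_upscale (an assumption made for brevity; common_upscale also offers bislerp and cropping):

import torch
import torch.nn.functional as F

h = torch.randn(2, 320, 64, 64)                # (B, C, H, W) hidden states of an input block
downscale_factor = 2.0

small = F.interpolate(h, size=(round(h.shape[-2] / downscale_factor),
                               round(h.shape[-1] / downscale_factor)),
                      mode="bicubic", align_corners=False)
print(small.shape)                             # torch.Size([2, 320, 32, 32])

restored = F.interpolate(small, size=h.shape[-2:], mode="bicubic", align_corners=False)
print(restored.shape)                          # torch.Size([2, 320, 64, 64])
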
ldm_patched/contrib/external_model_merging.py ADDED
@@ -0,0 +1,286 @@
1
+ # https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
2
+
3
+ import ldm_patched.modules.sd
4
+ import ldm_patched.modules.utils
5
+ import ldm_patched.modules.model_base
6
+ import ldm_patched.modules.model_management
7
+
8
+ import ldm_patched.utils.path_utils
9
+ import json
10
+ import os
11
+
12
+ from ldm_patched.modules.args_parser import args
13
+
14
+ class ModelMergeSimple:
15
+ @classmethod
16
+ def INPUT_TYPES(s):
17
+ return {"required": { "model1": ("MODEL",),
18
+ "model2": ("MODEL",),
19
+ "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
20
+ }}
21
+ RETURN_TYPES = ("MODEL",)
22
+ FUNCTION = "merge"
23
+
24
+ CATEGORY = "advanced/model_merging"
25
+
26
+ def merge(self, model1, model2, ratio):
27
+ m = model1.clone()
28
+ kp = model2.get_key_patches("diffusion_model.")
29
+ for k in kp:
30
+ m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
31
+ return (m, )
32
+
33
+ class ModelSubtract:
34
+ @classmethod
35
+ def INPUT_TYPES(s):
36
+ return {"required": { "model1": ("MODEL",),
37
+ "model2": ("MODEL",),
38
+ "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
39
+ }}
40
+ RETURN_TYPES = ("MODEL",)
41
+ FUNCTION = "merge"
42
+
43
+ CATEGORY = "advanced/model_merging"
44
+
45
+ def merge(self, model1, model2, multiplier):
46
+ m = model1.clone()
47
+ kp = model2.get_key_patches("diffusion_model.")
48
+ for k in kp:
49
+ m.add_patches({k: kp[k]}, - multiplier, multiplier)
50
+ return (m, )
51
+
52
+ class ModelAdd:
53
+ @classmethod
54
+ def INPUT_TYPES(s):
55
+ return {"required": { "model1": ("MODEL",),
56
+ "model2": ("MODEL",),
57
+ }}
58
+ RETURN_TYPES = ("MODEL",)
59
+ FUNCTION = "merge"
60
+
61
+ CATEGORY = "advanced/model_merging"
62
+
63
+ def merge(self, model1, model2):
64
+ m = model1.clone()
65
+ kp = model2.get_key_patches("diffusion_model.")
66
+ for k in kp:
67
+ m.add_patches({k: kp[k]}, 1.0, 1.0)
68
+ return (m, )
69
+
70
+
71
+ class CLIPMergeSimple:
72
+ @classmethod
73
+ def INPUT_TYPES(s):
74
+ return {"required": { "clip1": ("CLIP",),
75
+ "clip2": ("CLIP",),
76
+ "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
77
+ }}
78
+ RETURN_TYPES = ("CLIP",)
79
+ FUNCTION = "merge"
80
+
81
+ CATEGORY = "advanced/model_merging"
82
+
83
+ def merge(self, clip1, clip2, ratio):
84
+ m = clip1.clone()
85
+ kp = clip2.get_key_patches()
86
+ for k in kp:
87
+ if k.endswith(".position_ids") or k.endswith(".logit_scale"):
88
+ continue
89
+ m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
90
+ return (m, )
91
+
92
+ class ModelMergeBlocks:
93
+ @classmethod
94
+ def INPUT_TYPES(s):
95
+ return {"required": { "model1": ("MODEL",),
96
+ "model2": ("MODEL",),
97
+ "input": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
98
+ "middle": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
99
+ "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
100
+ }}
101
+ RETURN_TYPES = ("MODEL",)
102
+ FUNCTION = "merge"
103
+
104
+ CATEGORY = "advanced/model_merging"
105
+
106
+ def merge(self, model1, model2, **kwargs):
107
+ m = model1.clone()
108
+ kp = model2.get_key_patches("diffusion_model.")
109
+ default_ratio = next(iter(kwargs.values()))
110
+
111
+ for k in kp:
112
+ ratio = default_ratio
113
+ k_unet = k[len("diffusion_model."):]
114
+
115
+ last_arg_size = 0
116
+ for arg in kwargs:
117
+ if k_unet.startswith(arg) and last_arg_size < len(arg):
118
+ ratio = kwargs[arg]
119
+ last_arg_size = len(arg)
120
+
121
+ m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
122
+ return (m, )
123
+
124
+ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None, output_dir=None, prompt=None, extra_pnginfo=None):
125
+ full_output_folder, filename, counter, subfolder, filename_prefix = ldm_patched.utils.path_utils.get_save_image_path(filename_prefix, output_dir)
126
+ prompt_info = ""
127
+ if prompt is not None:
128
+ prompt_info = json.dumps(prompt)
129
+
130
+ metadata = {}
131
+
132
+ enable_modelspec = True
133
+ if isinstance(model.model, ldm_patched.modules.model_base.SDXL):
134
+ metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
135
+ elif isinstance(model.model, ldm_patched.modules.model_base.SDXLRefiner):
136
+ metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
137
+ else:
138
+ enable_modelspec = False
139
+
140
+ if enable_modelspec:
141
+ metadata["modelspec.sai_model_spec"] = "1.0.0"
142
+ metadata["modelspec.implementation"] = "sgm"
143
+ metadata["modelspec.title"] = "{} {}".format(filename, counter)
144
+
145
+ #TODO:
146
+ # "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512",
147
+ # "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
148
+ # "v2-inpainting"
149
+
150
+ if model.model.model_type == ldm_patched.modules.model_base.ModelType.EPS:
151
+ metadata["modelspec.predict_key"] = "epsilon"
152
+ elif model.model.model_type == ldm_patched.modules.model_base.ModelType.V_PREDICTION:
153
+ metadata["modelspec.predict_key"] = "v"
154
+
155
+ if not args.disable_server_info:
156
+ metadata["prompt"] = prompt_info
157
+ if extra_pnginfo is not None:
158
+ for x in extra_pnginfo:
159
+ metadata[x] = json.dumps(extra_pnginfo[x])
160
+
161
+ output_checkpoint = f"{filename}_{counter:05}_.safetensors"
162
+ output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
163
+
164
+ ldm_patched.modules.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata)
165
+
166
+ class CheckpointSave:
167
+ def __init__(self):
168
+ self.output_dir = ldm_patched.utils.path_utils.get_output_directory()
169
+
170
+ @classmethod
171
+ def INPUT_TYPES(s):
172
+ return {"required": { "model": ("MODEL",),
173
+ "clip": ("CLIP",),
174
+ "vae": ("VAE",),
175
+ "filename_prefix": ("STRING", {"default": "checkpoints/ldm_patched"}),},
176
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
177
+ RETURN_TYPES = ()
178
+ FUNCTION = "save"
179
+ OUTPUT_NODE = True
180
+
181
+ CATEGORY = "advanced/model_merging"
182
+
183
+ def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
184
+ save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
185
+ return {}
186
+
187
+ class CLIPSave:
188
+ def __init__(self):
189
+ self.output_dir = ldm_patched.utils.path_utils.get_output_directory()
190
+
191
+ @classmethod
192
+ def INPUT_TYPES(s):
193
+ return {"required": { "clip": ("CLIP",),
194
+ "filename_prefix": ("STRING", {"default": "clip/ldm_patched"}),},
195
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
196
+ RETURN_TYPES = ()
197
+ FUNCTION = "save"
198
+ OUTPUT_NODE = True
199
+
200
+ CATEGORY = "advanced/model_merging"
201
+
202
+ def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None):
203
+ prompt_info = ""
204
+ if prompt is not None:
205
+ prompt_info = json.dumps(prompt)
206
+
207
+ metadata = {}
208
+ if not args.disable_server_info:
209
+ metadata["prompt"] = prompt_info
210
+ if extra_pnginfo is not None:
211
+ for x in extra_pnginfo:
212
+ metadata[x] = json.dumps(extra_pnginfo[x])
213
+
214
+ ldm_patched.modules.model_management.load_models_gpu([clip.load_model()])
215
+ clip_sd = clip.get_sd()
216
+
217
+ for prefix in ["clip_l.", "clip_g.", ""]:
218
+ k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))
219
+ current_clip_sd = {}
220
+ for x in k:
221
+ current_clip_sd[x] = clip_sd.pop(x)
222
+ if len(current_clip_sd) == 0:
223
+ continue
224
+
225
+ p = prefix[:-1]
226
+ replace_prefix = {}
227
+ filename_prefix_ = filename_prefix
228
+ if len(p) > 0:
229
+ filename_prefix_ = "{}_{}".format(filename_prefix_, p)
230
+ replace_prefix[prefix] = ""
231
+ replace_prefix["transformer."] = ""
232
+
233
+ full_output_folder, filename, counter, subfolder, filename_prefix_ = ldm_patched.utils.path_utils.get_save_image_path(filename_prefix_, self.output_dir)
234
+
235
+ output_checkpoint = f"{filename}_{counter:05}_.safetensors"
236
+ output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
237
+
238
+ current_clip_sd = ldm_patched.modules.utils.state_dict_prefix_replace(current_clip_sd, replace_prefix)
239
+
240
+ ldm_patched.modules.utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata)
241
+ return {}
242
+
243
+ class VAESave:
244
+ def __init__(self):
245
+ self.output_dir = ldm_patched.utils.path_utils.get_output_directory()
246
+
247
+ @classmethod
248
+ def INPUT_TYPES(s):
249
+ return {"required": { "vae": ("VAE",),
250
+ "filename_prefix": ("STRING", {"default": "vae/ldm_patched_vae"}),},
251
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
252
+ RETURN_TYPES = ()
253
+ FUNCTION = "save"
254
+ OUTPUT_NODE = True
255
+
256
+ CATEGORY = "advanced/model_merging"
257
+
258
+ def save(self, vae, filename_prefix, prompt=None, extra_pnginfo=None):
259
+ full_output_folder, filename, counter, subfolder, filename_prefix = ldm_patched.utils.path_utils.get_save_image_path(filename_prefix, self.output_dir)
260
+ prompt_info = ""
261
+ if prompt is not None:
262
+ prompt_info = json.dumps(prompt)
263
+
264
+ metadata = {}
265
+ if not args.disable_server_info:
266
+ metadata["prompt"] = prompt_info
267
+ if extra_pnginfo is not None:
268
+ for x in extra_pnginfo:
269
+ metadata[x] = json.dumps(extra_pnginfo[x])
270
+
271
+ output_checkpoint = f"{filename}_{counter:05}_.safetensors"
272
+ output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
273
+
274
+ ldm_patched.modules.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
275
+ return {}
276
+
277
+ NODE_CLASS_MAPPINGS = {
278
+ "ModelMergeSimple": ModelMergeSimple,
279
+ "ModelMergeBlocks": ModelMergeBlocks,
280
+ "ModelMergeSubtract": ModelSubtract,
281
+ "ModelMergeAdd": ModelAdd,
282
+ "CheckpointSave": CheckpointSave,
283
+ "CLIPMergeSimple": CLIPMergeSimple,
284
+ "CLIPSave": CLIPSave,
285
+ "VAESave": VAESave,
286
+ }
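
Per weight tensor, ModelMergeSimple above amounts to a linear interpolation between the two checkpoints, with ratio giving model2's share. A simplified sketch on plain state dicts (the real node patches keys lazily through add_patches rather than materialising a merged dict):

import torch

def merge_state_dicts(sd1, sd2, ratio):
    # result = (1 - ratio) * model1 + ratio * model2, key by key
    return {k: (1.0 - ratio) * sd1[k] + ratio * sd2[k] for k in sd1}

sd1 = {"diffusion_model.w": torch.zeros(2, 2)}
sd2 = {"diffusion_model.w": torch.ones(2, 2)}
print(merge_state_dicts(sd1, sd2, 0.25)["diffusion_model.w"])   # every entry is 0.25
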
ldm_patched/contrib/external_perpneg.py ADDED
@@ -0,0 +1,57 @@
1
+ # https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
2
+
3
+ import torch
4
+ import ldm_patched.modules.model_management
5
+ import ldm_patched.modules.sample
6
+ import ldm_patched.modules.samplers
7
+ import ldm_patched.modules.utils
8
+
9
+
10
+ class PerpNeg:
11
+ @classmethod
12
+ def INPUT_TYPES(s):
13
+ return {"required": {"model": ("MODEL", ),
14
+ "empty_conditioning": ("CONDITIONING", ),
15
+ "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0}),
16
+ }}
17
+ RETURN_TYPES = ("MODEL",)
18
+ FUNCTION = "patch"
19
+
20
+ CATEGORY = "_for_testing"
21
+
22
+ def patch(self, model, empty_conditioning, neg_scale):
23
+ m = model.clone()
24
+ nocond = ldm_patched.modules.sample.convert_cond(empty_conditioning)
25
+
26
+ def cfg_function(args):
27
+ model = args["model"]
28
+ noise_pred_pos = args["cond_denoised"]
29
+ noise_pred_neg = args["uncond_denoised"]
30
+ cond_scale = args["cond_scale"]
31
+ x = args["input"]
32
+ sigma = args["sigma"]
33
+ model_options = args["model_options"]
34
+ nocond_processed = ldm_patched.modules.samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative")
35
+
36
+ (noise_pred_nocond, _) = ldm_patched.modules.samplers.calc_cond_uncond_batch(model, nocond_processed, None, x, sigma, model_options)
37
+
38
+ pos = noise_pred_pos - noise_pred_nocond
39
+ neg = noise_pred_neg - noise_pred_nocond
40
+ perp = ((torch.mul(pos, neg).sum())/(torch.norm(neg)**2)) * neg
41
+ perp_neg = perp * neg_scale
42
+ cfg_result = noise_pred_nocond + cond_scale*(pos - perp_neg)
43
+ cfg_result = x - cfg_result
44
+ return cfg_result
45
+
46
+ m.set_model_sampler_cfg_function(cfg_function)
47
+
48
+ return (m, )
49
+
50
+
51
+ NODE_CLASS_MAPPINGS = {
52
+ "PerpNeg": PerpNeg,
53
+ }
54
+
55
+ NODE_DISPLAY_NAME_MAPPINGS = {
56
+ "PerpNeg": "Perp-Neg",
57
+ }
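
The heart of PerpNeg above is a vector projection: the component of the positive guidance direction that lies along the negative direction is scaled by neg_scale and removed, leaving (mostly) the perpendicular part. With toy 2-D vectors:

import torch

pos = torch.tensor([2.0, 1.0])     # positive minus no-cond prediction (toy values)
neg = torch.tensor([1.0, 0.0])     # negative minus no-cond prediction (toy values)

perp = (torch.mul(pos, neg).sum() / (torch.norm(neg) ** 2)) * neg   # projection of pos onto neg
print(perp)                # tensor([2., 0.])
print(pos - 1.0 * perp)    # tensor([0., 1.]) -- only the component perpendicular to neg remains
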
ldm_patched/contrib/external_photomaker.py ADDED
@@ -0,0 +1,189 @@
1
+ # https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import ldm_patched.utils.path_utils
6
+ import ldm_patched.modules.clip_model
7
+ import ldm_patched.modules.clip_vision
8
+ import ldm_patched.modules.ops
9
+
10
+ # code for model from: https://github.com/TencentARC/PhotoMaker/blob/main/photomaker/model.py under Apache License Version 2.0
11
+ VISION_CONFIG_DICT = {
12
+ "hidden_size": 1024,
13
+ "image_size": 224,
14
+ "intermediate_size": 4096,
15
+ "num_attention_heads": 16,
16
+ "num_channels": 3,
17
+ "num_hidden_layers": 24,
18
+ "patch_size": 14,
19
+ "projection_dim": 768,
20
+ "hidden_act": "quick_gelu",
21
+ }
22
+
23
+ class MLP(nn.Module):
24
+ def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True, operations=ldm_patched.modules.ops):
25
+ super().__init__()
26
+ if use_residual:
27
+ assert in_dim == out_dim
28
+ self.layernorm = operations.LayerNorm(in_dim)
29
+ self.fc1 = operations.Linear(in_dim, hidden_dim)
30
+ self.fc2 = operations.Linear(hidden_dim, out_dim)
31
+ self.use_residual = use_residual
32
+ self.act_fn = nn.GELU()
33
+
34
+ def forward(self, x):
35
+ residual = x
36
+ x = self.layernorm(x)
37
+ x = self.fc1(x)
38
+ x = self.act_fn(x)
39
+ x = self.fc2(x)
40
+ if self.use_residual:
41
+ x = x + residual
42
+ return x
43
+
44
+
45
+ class FuseModule(nn.Module):
46
+ def __init__(self, embed_dim, operations):
47
+ super().__init__()
48
+ self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False, operations=operations)
49
+ self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True, operations=operations)
50
+ self.layer_norm = operations.LayerNorm(embed_dim)
51
+
52
+ def fuse_fn(self, prompt_embeds, id_embeds):
53
+ stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1)
54
+ stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds
55
+ stacked_id_embeds = self.mlp2(stacked_id_embeds)
56
+ stacked_id_embeds = self.layer_norm(stacked_id_embeds)
57
+ return stacked_id_embeds
58
+
59
+ def forward(
60
+ self,
61
+ prompt_embeds,
62
+ id_embeds,
63
+ class_tokens_mask,
64
+ ) -> torch.Tensor:
65
+ # id_embeds shape: [b, max_num_inputs, 1, 2048]
66
+ id_embeds = id_embeds.to(prompt_embeds.dtype)
67
+ num_inputs = class_tokens_mask.sum().unsqueeze(0) # TODO: check for training case
68
+ batch_size, max_num_inputs = id_embeds.shape[:2]
69
+ # seq_length: 77
70
+ seq_length = prompt_embeds.shape[1]
71
+ # flat_id_embeds shape: [b*max_num_inputs, 1, 2048]
72
+ flat_id_embeds = id_embeds.view(
73
+ -1, id_embeds.shape[-2], id_embeds.shape[-1]
74
+ )
75
+ # valid_id_mask [b*max_num_inputs]
76
+ valid_id_mask = (
77
+ torch.arange(max_num_inputs, device=flat_id_embeds.device)[None, :]
78
+ < num_inputs[:, None]
79
+ )
80
+ valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()]
81
+
82
+ prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1])
83
+ class_tokens_mask = class_tokens_mask.view(-1)
84
+ valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1])
85
+ # slice out the image token embeddings
86
+ image_token_embeds = prompt_embeds[class_tokens_mask]
87
+ stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds)
88
+ assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}"
89
+ prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype))
90
+ updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1)
91
+ return updated_prompt_embeds
92
+
93
+ class PhotoMakerIDEncoder(ldm_patched.modules.clip_model.CLIPVisionModelProjection):
94
+ def __init__(self):
95
+ self.load_device = ldm_patched.modules.model_management.text_encoder_device()
96
+ offload_device = ldm_patched.modules.model_management.text_encoder_offload_device()
97
+ dtype = ldm_patched.modules.model_management.text_encoder_dtype(self.load_device)
98
+
99
+ super().__init__(VISION_CONFIG_DICT, dtype, offload_device, ldm_patched.modules.ops.manual_cast)
100
+ self.visual_projection_2 = ldm_patched.modules.ops.manual_cast.Linear(1024, 1280, bias=False)
101
+ self.fuse_module = FuseModule(2048, ldm_patched.modules.ops.manual_cast)
102
+
103
+ def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask):
104
+ b, num_inputs, c, h, w = id_pixel_values.shape
105
+ id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w)
106
+
107
+ shared_id_embeds = self.vision_model(id_pixel_values)[2]
108
+ id_embeds = self.visual_projection(shared_id_embeds)
109
+ id_embeds_2 = self.visual_projection_2(shared_id_embeds)
110
+
111
+ id_embeds = id_embeds.view(b, num_inputs, 1, -1)
112
+ id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1)
113
+
114
+ id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1)
115
+ updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask)
116
+
117
+ return updated_prompt_embeds
118
+
119
+
120
+ class PhotoMakerLoader:
121
+ @classmethod
122
+ def INPUT_TYPES(s):
123
+ return {"required": { "photomaker_model_name": (ldm_patched.utils.path_utils.get_filename_list("photomaker"), )}}
124
+
125
+ RETURN_TYPES = ("PHOTOMAKER",)
126
+ FUNCTION = "load_photomaker_model"
127
+
128
+ CATEGORY = "_for_testing/photomaker"
129
+
130
+ def load_photomaker_model(self, photomaker_model_name):
131
+ photomaker_model_path = ldm_patched.utils.path_utils.get_full_path("photomaker", photomaker_model_name)
132
+ photomaker_model = PhotoMakerIDEncoder()
133
+ data = ldm_patched.modules.utils.load_torch_file(photomaker_model_path, safe_load=True)
134
+ if "id_encoder" in data:
135
+ data = data["id_encoder"]
136
+ photomaker_model.load_state_dict(data)
137
+ return (photomaker_model,)
138
+
139
+
140
+ class PhotoMakerEncode:
141
+ @classmethod
142
+ def INPUT_TYPES(s):
143
+ return {"required": { "photomaker": ("PHOTOMAKER",),
144
+ "image": ("IMAGE",),
145
+ "clip": ("CLIP", ),
146
+ "text": ("STRING", {"multiline": True, "default": "photograph of photomaker"}),
147
+ }}
148
+
149
+ RETURN_TYPES = ("CONDITIONING",)
150
+ FUNCTION = "apply_photomaker"
151
+
152
+ CATEGORY = "_for_testing/photomaker"
153
+
154
+ def apply_photomaker(self, photomaker, image, clip, text):
155
+ special_token = "photomaker"
156
+ pixel_values = ldm_patched.modules.clip_vision.clip_preprocess(image.to(photomaker.load_device)).float()
157
+ try:
158
+ index = text.split(" ").index(special_token) + 1
159
+ except ValueError:
160
+ index = -1
161
+ tokens = clip.tokenize(text, return_word_ids=True)
162
+ out_tokens = {}
163
+ for k in tokens:
164
+ out_tokens[k] = []
165
+ for t in tokens[k]:
166
+ f = list(filter(lambda x: x[2] != index, t))
167
+ while len(f) < len(t):
168
+ f.append(t[-1])
169
+ out_tokens[k].append(f)
170
+
171
+ cond, pooled = clip.encode_from_tokens(out_tokens, return_pooled=True)
172
+
173
+ if index > 0:
174
+ token_index = index - 1
175
+ num_id_images = 1
176
+ class_tokens_mask = [True if token_index <= i < token_index+num_id_images else False for i in range(77)]
177
+ out = photomaker(id_pixel_values=pixel_values.unsqueeze(0), prompt_embeds=cond.to(photomaker.load_device),
178
+ class_tokens_mask=torch.tensor(class_tokens_mask, dtype=torch.bool, device=photomaker.load_device).unsqueeze(0))
179
+ else:
180
+ out = cond
181
+
182
+ return ([[out, {"pooled_output": pooled}]], )
183
+
184
+
185
+ NODE_CLASS_MAPPINGS = {
186
+ "PhotoMakerLoader": PhotoMakerLoader,
187
+ "PhotoMakerEncode": PhotoMakerEncode,
188
+ }
189
+
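
For reference, this is all the class_tokens_mask in PhotoMakerEncode encodes: a 77-entry boolean list that is True only at the trigger-word position, so the fuse module knows which prompt embedding to replace with the ID embedding. The index below is a hypothetical toy value; in the node it comes from splitting the prompt on spaces:

token_index = 4       # hypothetical position of the "photomaker" trigger word in the prompt
num_id_images = 1

class_tokens_mask = [token_index <= i < token_index + num_id_images for i in range(77)]
print(class_tokens_mask.index(True), sum(class_tokens_mask))   # 4 1
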
ldm_patched/contrib/external_post_processing.py ADDED
@@ -0,0 +1,278 @@
1
+ # https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from PIL import Image
7
+ import math
8
+
9
+ import ldm_patched.modules.utils
10
+
11
+
12
+ class Blend:
13
+ def __init__(self):
14
+ pass
15
+
16
+ @classmethod
17
+ def INPUT_TYPES(s):
18
+ return {
19
+ "required": {
20
+ "image1": ("IMAGE",),
21
+ "image2": ("IMAGE",),
22
+ "blend_factor": ("FLOAT", {
23
+ "default": 0.5,
24
+ "min": 0.0,
25
+ "max": 1.0,
26
+ "step": 0.01
27
+ }),
28
+ "blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light", "difference"],),
29
+ },
30
+ }
31
+
32
+ RETURN_TYPES = ("IMAGE",)
33
+ FUNCTION = "blend_images"
34
+
35
+ CATEGORY = "image/postprocessing"
36
+
37
+ def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
38
+ image2 = image2.to(image1.device)
39
+ if image1.shape != image2.shape:
40
+ image2 = image2.permute(0, 3, 1, 2)
41
+ image2 = ldm_patched.modules.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center')
42
+ image2 = image2.permute(0, 2, 3, 1)
43
+
44
+ blended_image = self.blend_mode(image1, image2, blend_mode)
45
+ blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
46
+ blended_image = torch.clamp(blended_image, 0, 1)
47
+ return (blended_image,)
48
+
49
+ def blend_mode(self, img1, img2, mode):
50
+ if mode == "normal":
51
+ return img2
52
+ elif mode == "multiply":
53
+ return img1 * img2
54
+ elif mode == "screen":
55
+ return 1 - (1 - img1) * (1 - img2)
56
+ elif mode == "overlay":
57
+ return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
58
+ elif mode == "soft_light":
59
+ return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1))
60
+ elif mode == "difference":
61
+ return img1 - img2
62
+ else:
63
+ raise ValueError(f"Unsupported blend mode: {mode}")
64
+
65
+ def g(self, x):
66
+ return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
67
+
68
+ def gaussian_kernel(kernel_size: int, sigma: float, device=None):
69
+ x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size, device=device), torch.linspace(-1, 1, kernel_size, device=device), indexing="ij")
70
+ d = torch.sqrt(x * x + y * y)
71
+ g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
72
+ return g / g.sum()
73
+
74
+ class Blur:
75
+ def __init__(self):
76
+ pass
77
+
78
+ @classmethod
79
+ def INPUT_TYPES(s):
80
+ return {
81
+ "required": {
82
+ "image": ("IMAGE",),
83
+ "blur_radius": ("INT", {
84
+ "default": 1,
85
+ "min": 1,
86
+ "max": 31,
87
+ "step": 1
88
+ }),
89
+ "sigma": ("FLOAT", {
90
+ "default": 1.0,
91
+ "min": 0.1,
92
+ "max": 10.0,
93
+ "step": 0.1
94
+ }),
95
+ },
96
+ }
97
+
98
+ RETURN_TYPES = ("IMAGE",)
99
+ FUNCTION = "blur"
100
+
101
+ CATEGORY = "image/postprocessing"
102
+
103
+ def blur(self, image: torch.Tensor, blur_radius: int, sigma: float):
104
+ if blur_radius == 0:
105
+ return (image,)
106
+
107
+ batch_size, height, width, channels = image.shape
108
+
109
+ kernel_size = blur_radius * 2 + 1
110
+ kernel = gaussian_kernel(kernel_size, sigma, device=image.device).repeat(channels, 1, 1).unsqueeze(1)
111
+
112
+ image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
113
+ padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect')
114
+ blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius]
115
+ blurred = blurred.permute(0, 2, 3, 1)
116
+
117
+ return (blurred,)
118
+
119
+ class Quantize:
120
+ def __init__(self):
121
+ pass
122
+
123
+ @classmethod
124
+ def INPUT_TYPES(s):
125
+ return {
126
+ "required": {
127
+ "image": ("IMAGE",),
128
+ "colors": ("INT", {
129
+ "default": 256,
130
+ "min": 1,
131
+ "max": 256,
132
+ "step": 1
133
+ }),
134
+ "dither": (["none", "floyd-steinberg", "bayer-2", "bayer-4", "bayer-8", "bayer-16"],),
135
+ },
136
+ }
137
+
138
+ RETURN_TYPES = ("IMAGE",)
139
+ FUNCTION = "quantize"
140
+
141
+ CATEGORY = "image/postprocessing"
142
+
143
+ def bayer(im, pal_im, order):
144
+ def normalized_bayer_matrix(n):
145
+ if n == 0:
146
+ return np.zeros((1,1), "float32")
147
+ else:
148
+ q = 4 ** n
149
+ m = q * normalized_bayer_matrix(n - 1)
150
+ return np.bmat(((m-1.5, m+0.5), (m+1.5, m-0.5))) / q
151
+
152
+ num_colors = len(pal_im.getpalette()) // 3
153
+ spread = 2 * 256 / num_colors
154
+ bayer_n = int(math.log2(order))
155
+ bayer_matrix = torch.from_numpy(spread * normalized_bayer_matrix(bayer_n) + 0.5)
156
+
157
+ result = torch.from_numpy(np.array(im).astype(np.float32))
158
+ tw = math.ceil(result.shape[0] / bayer_matrix.shape[0])
159
+ th = math.ceil(result.shape[1] / bayer_matrix.shape[1])
160
+ tiled_matrix = bayer_matrix.tile(tw, th).unsqueeze(-1)
161
+ result.add_(tiled_matrix[:result.shape[0],:result.shape[1]]).clamp_(0, 255)
162
+ result = result.to(dtype=torch.uint8)
163
+
164
+ im = Image.fromarray(result.cpu().numpy())
165
+ im = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
166
+ return im
167
+
168
+ def quantize(self, image: torch.Tensor, colors: int, dither: str):
169
+ batch_size, height, width, _ = image.shape
170
+ result = torch.zeros_like(image)
171
+
172
+ for b in range(batch_size):
173
+ im = Image.fromarray((image[b] * 255).to(torch.uint8).numpy(), mode='RGB')
174
+
175
+ pal_im = im.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836
176
+
177
+ if dither == "none":
178
+ quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
179
+ elif dither == "floyd-steinberg":
180
+ quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.FLOYDSTEINBERG)
181
+ elif dither.startswith("bayer"):
182
+ order = int(dither.split('-')[-1])
183
+ quantized_image = Quantize.bayer(im, pal_im, order)
184
+
185
+ quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
186
+ result[b] = quantized_array
187
+
188
+ return (result,)
189
+
190
+ class Sharpen:
191
+ def __init__(self):
192
+ pass
193
+
194
+ @classmethod
195
+ def INPUT_TYPES(s):
196
+ return {
197
+ "required": {
198
+ "image": ("IMAGE",),
199
+ "sharpen_radius": ("INT", {
200
+ "default": 1,
201
+ "min": 1,
202
+ "max": 31,
203
+ "step": 1
204
+ }),
205
+ "sigma": ("FLOAT", {
206
+ "default": 1.0,
207
+ "min": 0.1,
208
+ "max": 10.0,
209
+ "step": 0.1
210
+ }),
211
+ "alpha": ("FLOAT", {
212
+ "default": 1.0,
213
+ "min": 0.0,
214
+ "max": 5.0,
215
+ "step": 0.1
216
+ }),
217
+ },
218
+ }
219
+
220
+ RETURN_TYPES = ("IMAGE",)
221
+ FUNCTION = "sharpen"
222
+
223
+ CATEGORY = "image/postprocessing"
224
+
225
+ def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float):
226
+ if sharpen_radius == 0:
227
+ return (image,)
228
+
229
+ batch_size, height, width, channels = image.shape
230
+
231
+ kernel_size = sharpen_radius * 2 + 1
232
+ kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10)
233
+ center = kernel_size // 2
234
+ kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
235
+ kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)
236
+
237
+ tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
238
+ tensor_image = F.pad(tensor_image, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), 'reflect')
239
+ sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
240
+ sharpened = sharpened.permute(0, 2, 3, 1)
241
+
242
+ result = torch.clamp(sharpened, 0, 1)
243
+
244
+ return (result,)
245
+
246
+ class ImageScaleToTotalPixels:
247
+ upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
248
+ crop_methods = ["disabled", "center"]
249
+
250
+ @classmethod
251
+ def INPUT_TYPES(s):
252
+ return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,),
253
+ "megapixels": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 16.0, "step": 0.01}),
254
+ }}
255
+ RETURN_TYPES = ("IMAGE",)
256
+ FUNCTION = "upscale"
257
+
258
+ CATEGORY = "image/upscaling"
259
+
260
+ def upscale(self, image, upscale_method, megapixels):
261
+ samples = image.movedim(-1,1)
262
+ total = int(megapixels * 1024 * 1024)
263
+
264
+ scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
265
+ width = round(samples.shape[3] * scale_by)
266
+ height = round(samples.shape[2] * scale_by)
267
+
268
+ s = ldm_patched.modules.utils.common_upscale(samples, width, height, upscale_method, "disabled")
269
+ s = s.movedim(1,-1)
270
+ return (s,)
271
+
272
+ NODE_CLASS_MAPPINGS = {
273
+ "ImageBlend": Blend,
274
+ "ImageBlur": Blur,
275
+ "ImageQuantize": Quantize,
276
+ "ImageSharpen": Sharpen,
277
+ "ImageScaleToTotalPixels": ImageScaleToTotalPixels,
278
+ }
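
Two quick sanity checks on the kernels above, runnable on their own: the Gaussian kernel is normalised to sum to 1 (so Blur preserves overall brightness), and Sharpen re-centres its negated kernel so that it also sums to 1:

import torch

def gaussian_kernel(kernel_size, sigma):
    # CPU-only copy of the helper above
    x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size),
                          torch.linspace(-1, 1, kernel_size), indexing="ij")
    d = torch.sqrt(x * x + y * y)
    g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
    return g / g.sum()

k = gaussian_kernel(5, 1.0)
print(k.shape, float(k.sum()))      # torch.Size([5, 5]) ~1.0

alpha = 1.0
sharp = gaussian_kernel(5, 1.0) * -(alpha * 10)
center = 5 // 2
sharp[center, center] = sharp[center, center] - sharp.sum() + 1.0
print(float(sharp.sum()))           # ~1.0, so Sharpen keeps the image's mean level
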
ldm_patched/contrib/external_rebatch.py ADDED
@@ -0,0 +1,140 @@
1
+ # https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
2
+
3
+ import torch
4
+
5
+ class LatentRebatch:
6
+ @classmethod
7
+ def INPUT_TYPES(s):
8
+ return {"required": { "latents": ("LATENT",),
9
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
10
+ }}
11
+ RETURN_TYPES = ("LATENT",)
12
+ INPUT_IS_LIST = True
13
+ OUTPUT_IS_LIST = (True, )
14
+
15
+ FUNCTION = "rebatch"
16
+
17
+ CATEGORY = "latent/batch"
18
+
19
+ @staticmethod
20
+ def get_batch(latents, list_ind, offset):
21
+ '''prepare a batch out of the list of latents'''
22
+ samples = latents[list_ind]['samples']
23
+ shape = samples.shape
24
+ mask = latents[list_ind]['noise_mask'] if 'noise_mask' in latents[list_ind] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu')
25
+ if mask.shape[-1] != shape[-1] * 8 or mask.shape[-2] != shape[-2] * 8:
26
+ mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear")  # keep the resized mask instead of discarding it
27
+ if mask.shape[0] < samples.shape[0]:
28
+ mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]]
29
+ if 'batch_index' in latents[list_ind]:
30
+ batch_inds = latents[list_ind]['batch_index']
31
+ else:
32
+ batch_inds = [x+offset for x in range(shape[0])]
33
+ return samples, mask, batch_inds
34
+
35
+ @staticmethod
36
+ def get_slices(indexable, num, batch_size):
37
+ '''divides an indexable object into num slices of length batch_size, and a remainder'''
38
+ slices = []
39
+ for i in range(num):
40
+ slices.append(indexable[i*batch_size:(i+1)*batch_size])
41
+ if num * batch_size < len(indexable):
42
+ return slices, indexable[num * batch_size:]
43
+ else:
44
+ return slices, None
45
+
46
+ @staticmethod
47
+ def slice_batch(batch, num, batch_size):
48
+ result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch]
49
+ return list(zip(*result))
50
+
51
+ @staticmethod
52
+ def cat_batch(batch1, batch2):
53
+ if batch1[0] is None:
54
+ return batch2
55
+ result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
56
+ return result
57
+
58
+ def rebatch(self, latents, batch_size):
59
+ batch_size = batch_size[0]
60
+
61
+ output_list = []
62
+ current_batch = (None, None, None)
63
+ processed = 0
64
+
65
+ for i in range(len(latents)):
66
+ # fetch new entry of list
67
+ #samples, masks, indices = self.get_batch(latents, i)
68
+ next_batch = self.get_batch(latents, i, processed)
69
+ processed += len(next_batch[2])
70
+ # set to current if current is None
71
+ if current_batch[0] is None:
72
+ current_batch = next_batch
73
+ # add previous to list if dimensions do not match
74
+ elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]:
75
+ sliced, _ = self.slice_batch(current_batch, 1, batch_size)
76
+ output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
77
+ current_batch = next_batch
78
+ # cat if everything checks out
79
+ else:
80
+ current_batch = self.cat_batch(current_batch, next_batch)
81
+
82
+ # add to list if dimensions gone above target batch size
83
+ if current_batch[0].shape[0] > batch_size:
84
+ num = current_batch[0].shape[0] // batch_size
85
+ sliced, remainder = self.slice_batch(current_batch, num, batch_size)
86
+
87
+ for i in range(num):
88
+ output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})
89
+
90
+ current_batch = remainder
91
+
92
+ #add remainder
93
+ if current_batch[0] is not None:
94
+ sliced, _ = self.slice_batch(current_batch, 1, batch_size)
95
+ output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]})
96
+
97
+ #get rid of empty masks
98
+ for s in output_list:
99
+ if s['noise_mask'].mean() == 1.0:
100
+ del s['noise_mask']
101
+
102
+ return (output_list,)
103
+
104
+ class ImageRebatch:
105
+ @classmethod
106
+ def INPUT_TYPES(s):
107
+ return {"required": { "images": ("IMAGE",),
108
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
109
+ }}
110
+ RETURN_TYPES = ("IMAGE",)
111
+ INPUT_IS_LIST = True
112
+ OUTPUT_IS_LIST = (True, )
113
+
114
+ FUNCTION = "rebatch"
115
+
116
+ CATEGORY = "image/batch"
117
+
118
+ def rebatch(self, images, batch_size):
119
+ batch_size = batch_size[0]
120
+
121
+ output_list = []
122
+ all_images = []
123
+ for img in images:
124
+ for i in range(img.shape[0]):
125
+ all_images.append(img[i:i+1])
126
+
127
+ for i in range(0, len(all_images), batch_size):
128
+ output_list.append(torch.cat(all_images[i:i+batch_size], dim=0))
129
+
130
+ return (output_list,)
131
+
132
+ NODE_CLASS_MAPPINGS = {
133
+ "RebatchLatents": LatentRebatch,
134
+ "RebatchImages": ImageRebatch,
135
+ }
136
+
137
+ NODE_DISPLAY_NAME_MAPPINGS = {
138
+ "RebatchLatents": "Rebatch Latents",
139
+ "RebatchImages": "Rebatch Images",
140
+ }
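
The slicing helper in LatentRebatch above works on anything indexable, so it can be exercised on a plain list: it returns num chunks of batch_size items plus whatever is left over (or None):

def get_slices(indexable, num, batch_size):
    # same logic as LatentRebatch.get_slices
    slices = [indexable[i * batch_size:(i + 1) * batch_size] for i in range(num)]
    remainder = indexable[num * batch_size:] if num * batch_size < len(indexable) else None
    return slices, remainder

print(get_slices(list(range(10)), 3, 3))   # ([[0, 1, 2], [3, 4, 5], [6, 7, 8]], [9])
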
ldm_patched/contrib/external_sag.py ADDED
@@ -0,0 +1,172 @@
1
+ # https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
2
+
3
+ import torch
4
+ from torch import einsum
5
+ import torch.nn.functional as F
6
+ import math
7
+
8
+ from einops import rearrange, repeat
9
+ import os
10
+ from ldm_patched.ldm.modules.attention import optimized_attention, _ATTN_PRECISION
11
+ import ldm_patched.modules.samplers
12
+
13
+ # from ldm_patched/ldm/modules/attention.py
14
+ # but modified to return attention scores as well as output
15
+ def attention_basic_with_sim(q, k, v, heads, mask=None):
16
+ b, _, dim_head = q.shape
17
+ dim_head //= heads
18
+ scale = dim_head ** -0.5
19
+
20
+ h = heads
21
+ q, k, v = map(
22
+ lambda t: t.unsqueeze(3)
23
+ .reshape(b, -1, heads, dim_head)
24
+ .permute(0, 2, 1, 3)
25
+ .reshape(b * heads, -1, dim_head)
26
+ .contiguous(),
27
+ (q, k, v),
28
+ )
29
+
30
+ # force cast to fp32 to avoid overflowing
31
+ if _ATTN_PRECISION =="fp32":
32
+ sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
33
+ else:
34
+ sim = einsum('b i d, b j d -> b i j', q, k) * scale
35
+
36
+ del q, k
37
+
38
+ if mask is not None:
39
+ mask = rearrange(mask, 'b ... -> b (...)')
40
+ max_neg_value = -torch.finfo(sim.dtype).max
41
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
42
+ sim.masked_fill_(~mask, max_neg_value)
43
+
44
+ # attention, what we cannot get enough of
45
+ sim = sim.softmax(dim=-1)
46
+
47
+ out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
48
+ out = (
49
+ out.unsqueeze(0)
50
+ .reshape(b, heads, -1, dim_head)
51
+ .permute(0, 2, 1, 3)
52
+ .reshape(b, -1, heads * dim_head)
53
+ )
54
+ return (out, sim)
55
+
56
+ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
57
+ # reshape and GAP the attention map
58
+ _, hw1, hw2 = attn.shape
59
+ b, _, lh, lw = x0.shape
60
+ attn = attn.reshape(b, -1, hw1, hw2)
61
+ # Global Average Pool
62
+ mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
63
+ ratio = 2**(math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length()
64
+ mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
65
+
66
+ # Reshape
67
+ mask = (
68
+ mask.reshape(b, *mid_shape)
69
+ .unsqueeze(1)
70
+ .type(attn.dtype)
71
+ )
72
+ # Upsample
73
+ mask = F.interpolate(mask, (lh, lw))
74
+
75
+ blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma)
76
+ blurred = blurred * mask + x0 * (1 - mask)
77
+ return blurred
78
+
79
+ def gaussian_blur_2d(img, kernel_size, sigma):
80
+ ksize_half = (kernel_size - 1) * 0.5
81
+
82
+ x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
83
+
84
+ pdf = torch.exp(-0.5 * (x / sigma).pow(2))
85
+
86
+ x_kernel = pdf / pdf.sum()
87
+ x_kernel = x_kernel.to(device=img.device, dtype=img.dtype)
88
+
89
+ kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])
90
+ kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1])
91
+
92
+ padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
93
+
94
+ img = F.pad(img, padding, mode="reflect")
95
+ img = F.conv2d(img, kernel2d, groups=img.shape[-3])
96
+ return img
97
+
98
+ class SelfAttentionGuidance:
99
+ @classmethod
100
+ def INPUT_TYPES(s):
101
+ return {"required": { "model": ("MODEL",),
102
+ "scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.1}),
103
+ "blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}),
104
+ }}
105
+ RETURN_TYPES = ("MODEL",)
106
+ FUNCTION = "patch"
107
+
108
+ CATEGORY = "_for_testing"
109
+
110
+ def patch(self, model, scale, blur_sigma):
111
+ m = model.clone()
112
+
113
+ attn_scores = None
114
+
115
+ # TODO: make this work properly with chunked batches
116
+ # currently, we can only save the attn from one UNet call
117
+ def attn_and_record(q, k, v, extra_options):
118
+ nonlocal attn_scores
119
+ # if uncond, save the attention scores
120
+ heads = extra_options["n_heads"]
121
+ cond_or_uncond = extra_options["cond_or_uncond"]
122
+ b = q.shape[0] // len(cond_or_uncond)
123
+ if 1 in cond_or_uncond:
124
+ uncond_index = cond_or_uncond.index(1)
125
+ # do the entire attention operation, but save the attention scores to attn_scores
126
+ (out, sim) = attention_basic_with_sim(q, k, v, heads=heads)
127
+ # when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn]
128
+ n_slices = heads * b
129
+ attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)]
130
+ return out
131
+ else:
132
+ return optimized_attention(q, k, v, heads=heads)
133
+
134
+ def post_cfg_function(args):
135
+ nonlocal attn_scores
136
+ uncond_attn = attn_scores
137
+
138
+ sag_scale = scale
139
+ sag_sigma = blur_sigma
140
+ sag_threshold = 1.0
141
+ model = args["model"]
142
+ uncond_pred = args["uncond_denoised"]
143
+ uncond = args["uncond"]
144
+ cfg_result = args["denoised"]
145
+ sigma = args["sigma"]
146
+ model_options = args["model_options"]
147
+ x = args["input"]
148
+ if min(cfg_result.shape[2:]) <= 4: #skip when too small to add padding
149
+ return cfg_result
150
+
151
+ # create the adversarially blurred image
152
+ degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
153
+ degraded_noised = degraded + x - uncond_pred
154
+ # call into the UNet
155
+ (sag, _) = ldm_patched.modules.samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
156
+ return cfg_result + (degraded - sag) * sag_scale
157
+
158
+ m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)
159
+
160
+ # from diffusers:
161
+ # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
162
+ m.set_model_attn1_replace(attn_and_record, "middle", 0, 0)
163
+
164
+ return (m, )
165
+
166
+ NODE_CLASS_MAPPINGS = {
167
+ "SelfAttentionGuidance": SelfAttentionGuidance,
168
+ }
169
+
170
+ NODE_DISPLAY_NAME_MAPPINGS = {
171
+ "SelfAttentionGuidance": "Self-Attention Guidance",
172
+ }
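
The blur used for Self-Attention Guidance is separable: gaussian_blur_2d above builds its 2-D kernel as the outer product of a 1-D Gaussian. A standalone check that the kernel is normalised:

import torch

kernel_size, sigma = 9, 2.0
ksize_half = (kernel_size - 1) * 0.5
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
x_kernel = pdf / pdf.sum()

kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :])   # outer product -> 2-D Gaussian
print(kernel2d.shape, float(kernel2d.sum()))                # torch.Size([9, 9]) ~1.0
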
ldm_patched/contrib/external_sdupscale.py ADDED
@@ -0,0 +1,49 @@
1
+ # https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
2
+
3
+ import torch
4
+ import ldm_patched.contrib.external
5
+ import ldm_patched.modules.utils
6
+
7
+ class SD_4XUpscale_Conditioning:
8
+ @classmethod
9
+ def INPUT_TYPES(s):
10
+ return {"required": { "images": ("IMAGE",),
11
+ "positive": ("CONDITIONING",),
12
+ "negative": ("CONDITIONING",),
13
+ "scale_ratio": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.01}),
14
+ "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
15
+ }}
16
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
17
+ RETURN_NAMES = ("positive", "negative", "latent")
18
+
19
+ FUNCTION = "encode"
20
+
21
+ CATEGORY = "conditioning/upscale_diffusion"
22
+
23
+ def encode(self, images, positive, negative, scale_ratio, noise_augmentation):
24
+ width = max(1, round(images.shape[-2] * scale_ratio))
25
+ height = max(1, round(images.shape[-3] * scale_ratio))
26
+
27
+ pixels = ldm_patched.modules.utils.common_upscale((images.movedim(-1,1) * 2.0) - 1.0, width // 4, height // 4, "bilinear", "center")
28
+
29
+ out_cp = []
30
+ out_cn = []
31
+
32
+ for t in positive:
33
+ n = [t[0], t[1].copy()]
34
+ n[1]['concat_image'] = pixels
35
+ n[1]['noise_augmentation'] = noise_augmentation
36
+ out_cp.append(n)
37
+
38
+ for t in negative:
39
+ n = [t[0], t[1].copy()]
40
+ n[1]['concat_image'] = pixels
41
+ n[1]['noise_augmentation'] = noise_augmentation
42
+ out_cn.append(n)
43
+
44
+ latent = torch.zeros([images.shape[0], 4, height // 4, width // 4])
45
+ return (out_cp, out_cn, {"samples":latent})
46
+
47
+ NODE_CLASS_MAPPINGS = {
48
+ "SD_4XUpscale_Conditioning": SD_4XUpscale_Conditioning,
49
+ }
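
The size arithmetic in SD_4XUpscale_Conditioning above, with toy numbers: the conditioning image and the empty latent are both a quarter of the requested output resolution, so decoding through the 4x upscale model lands on scale_ratio times the input size:

import torch

images = torch.zeros(1, 512, 512, 3)        # (B, H, W, C), as IMAGE tensors are laid out here
scale_ratio = 4.0

width = max(1, round(images.shape[-2] * scale_ratio))
height = max(1, round(images.shape[-3] * scale_ratio))
latent = torch.zeros([images.shape[0], 4, height // 4, width // 4])

print(width, height, list(latent.shape))    # 2048 2048 [1, 4, 512, 512]
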
ldm_patched/contrib/external_stable3d.py ADDED
@@ -0,0 +1,104 @@
1
+ # https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
2
+
3
+ import torch
4
+ import ldm_patched.contrib.external
5
+ import ldm_patched.modules.utils
6
+
7
+ def camera_embeddings(elevation, azimuth):
8
+ elevation = torch.as_tensor([elevation])
9
+ azimuth = torch.as_tensor([azimuth])
10
+ embeddings = torch.stack(
11
+ [
12
+ torch.deg2rad(
13
+ (90 - elevation) - (90)
14
+ ), # Zero123 polar is 90-elevation
15
+ torch.sin(torch.deg2rad(azimuth)),
16
+ torch.cos(torch.deg2rad(azimuth)),
17
+ torch.deg2rad(
18
+ 90 - torch.full_like(elevation, 0)
19
+ ),
20
+ ], dim=-1).unsqueeze(1)
21
+
22
+ return embeddings
23
+
24
+
25
+ class StableZero123_Conditioning:
26
+ @classmethod
27
+ def INPUT_TYPES(s):
28
+ return {"required": { "clip_vision": ("CLIP_VISION",),
29
+ "init_image": ("IMAGE",),
30
+ "vae": ("VAE",),
31
+ "width": ("INT", {"default": 256, "min": 16, "max": ldm_patched.contrib.external.MAX_RESOLUTION, "step": 8}),
32
+ "height": ("INT", {"default": 256, "min": 16, "max": ldm_patched.contrib.external.MAX_RESOLUTION, "step": 8}),
33
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
34
+ "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
35
+ "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
36
+ }}
37
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
38
+ RETURN_NAMES = ("positive", "negative", "latent")
39
+
40
+ FUNCTION = "encode"
41
+
42
+ CATEGORY = "conditioning/3d_models"
43
+
44
+ def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth):
45
+ output = clip_vision.encode_image(init_image)
46
+ pooled = output.image_embeds.unsqueeze(0)
47
+ pixels = ldm_patched.modules.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
48
+ encode_pixels = pixels[:,:,:,:3]
49
+ t = vae.encode(encode_pixels)
50
+ cam_embeds = camera_embeddings(elevation, azimuth)
51
+ cond = torch.cat([pooled, cam_embeds.to(pooled.device).repeat((pooled.shape[0], 1, 1))], dim=-1)
52
+
53
+ positive = [[cond, {"concat_latent_image": t}]]
54
+ negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
55
+ latent = torch.zeros([batch_size, 4, height // 8, width // 8])
56
+ return (positive, negative, {"samples":latent})
57
+
58
+ class StableZero123_Conditioning_Batched:
59
+ @classmethod
60
+ def INPUT_TYPES(s):
61
+ return {"required": { "clip_vision": ("CLIP_VISION",),
62
+ "init_image": ("IMAGE",),
63
+ "vae": ("VAE",),
64
+ "width": ("INT", {"default": 256, "min": 16, "max": ldm_patched.contrib.external.MAX_RESOLUTION, "step": 8}),
65
+ "height": ("INT", {"default": 256, "min": 16, "max": ldm_patched.contrib.external.MAX_RESOLUTION, "step": 8}),
66
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
67
+ "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
68
+ "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
69
+ "elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
70
+ "azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0}),
71
+ }}
72
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
73
+ RETURN_NAMES = ("positive", "negative", "latent")
74
+
75
+ FUNCTION = "encode"
76
+
77
+ CATEGORY = "conditioning/3d_models"
78
+
79
+ def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth, elevation_batch_increment, azimuth_batch_increment):
80
+ output = clip_vision.encode_image(init_image)
81
+ pooled = output.image_embeds.unsqueeze(0)
82
+ pixels = ldm_patched.modules.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
83
+ encode_pixels = pixels[:,:,:,:3]
84
+ t = vae.encode(encode_pixels)
85
+
86
+ cam_embeds = []
87
+ for i in range(batch_size):
88
+ cam_embeds.append(camera_embeddings(elevation, azimuth))
89
+ elevation += elevation_batch_increment
90
+ azimuth += azimuth_batch_increment
91
+
92
+ cam_embeds = torch.cat(cam_embeds, dim=0)
93
+ cond = torch.cat([ldm_patched.modules.utils.repeat_to_batch_size(pooled, batch_size), cam_embeds], dim=-1)
94
+
95
+ positive = [[cond, {"concat_latent_image": t}]]
96
+ negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]]
97
+ latent = torch.zeros([batch_size, 4, height // 8, width // 8])
98
+ return (positive, negative, {"samples":latent, "batch_index": [0] * batch_size})
99
+
100
+
101
+ NODE_CLASS_MAPPINGS = {
102
+ "StableZero123_Conditioning": StableZero123_Conditioning,
103
+ "StableZero123_Conditioning_Batched": StableZero123_Conditioning_Batched,
104
+ }
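camera_embeddings() above packs four scalars per view: the negated elevation in radians, the sine and cosine of the azimuth, and a constant 90-degree term. A worked example for one hypothetical view, using plain math:

import math

elevation, azimuth = 10.0, 30.0  # illustrative values

embedding = [
    math.radians((90 - elevation) - 90),  # Zero123 polar convention: -elevation in radians
    math.sin(math.radians(azimuth)),
    math.cos(math.radians(azimuth)),
    math.radians(90 - 0),                 # constant term, matching the torch.full_like(...) branch
]
print([round(v, 4) for v in embedding])   # [-0.1745, 0.5, 0.866, 1.5708]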
ldm_patched/contrib/external_tomesd.py ADDED
@@ -0,0 +1,179 @@
1
+ # https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
2
+
3
+ #Taken from: https://github.com/dbolya/tomesd
4
+
5
+ import torch
6
+ from typing import Tuple, Callable
7
+ import math
8
+
9
+ def do_nothing(x: torch.Tensor, mode:str=None):
10
+ return x
11
+
12
+
13
+ def mps_gather_workaround(input, dim, index):
14
+ if input.shape[-1] == 1:
15
+ return torch.gather(
16
+ input.unsqueeze(-1),
17
+ dim - 1 if dim < 0 else dim,
18
+ index.unsqueeze(-1)
19
+ ).squeeze(-1)
20
+ else:
21
+ return torch.gather(input, dim, index)
22
+
23
+
24
+ def bipartite_soft_matching_random2d(metric: torch.Tensor,
25
+ w: int, h: int, sx: int, sy: int, r: int,
26
+ no_rand: bool = False) -> Tuple[Callable, Callable]:
27
+ """
28
+ Partitions the tokens into src and dst and merges r tokens from src to dst.
29
+ Dst tokens are partitioned by choosing one randomly in each (sx, sy) region.
30
+ Args:
31
+ - metric [B, N, C]: metric to use for similarity
32
+ - w: image width in tokens
33
+ - h: image height in tokens
34
+ - sx: stride in the x dimension for dst, must divide w
35
+ - sy: stride in the y dimension for dst, must divide h
36
+ - r: number of tokens to remove (by merging)
37
+ - no_rand: if true, disable randomness (use top left corner only)
38
+ """
39
+ B, N, _ = metric.shape
40
+
41
+ if r <= 0 or w == 1 or h == 1:
42
+ return do_nothing, do_nothing
43
+
44
+ gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
45
+
46
+ with torch.no_grad():
47
+
48
+ hsy, wsx = h // sy, w // sx
49
+
50
+ # For each sy by sx kernel, randomly assign one token to be dst and the rest src
51
+ if no_rand:
52
+ rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
53
+ else:
54
+ rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device)
55
+
56
+ # The image might not divide sx and sy, so we need to work on a view of the top left of the idx buffer instead
57
+ idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)
58
+ idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
59
+ idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)
60
+
61
+ # Image is not divisible by sx or sy so we need to move it into a new buffer
62
+ if (hsy * sy) < h or (wsx * sx) < w:
63
+ idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
64
+ idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
65
+ else:
66
+ idx_buffer = idx_buffer_view
67
+
68
+ # We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
69
+ rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)
70
+
71
+ # We're finished with these
72
+ del idx_buffer, idx_buffer_view
73
+
74
+ # rand_idx is currently dst|src, so split them
75
+ num_dst = hsy * wsx
76
+ a_idx = rand_idx[:, num_dst:, :] # src
77
+ b_idx = rand_idx[:, :num_dst, :] # dst
78
+
79
+ def split(x):
80
+ C = x.shape[-1]
81
+ src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
82
+ dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
83
+ return src, dst
84
+
85
+ # Cosine similarity between A and B
86
+ metric = metric / metric.norm(dim=-1, keepdim=True)
87
+ a, b = split(metric)
88
+ scores = a @ b.transpose(-1, -2)
89
+
90
+ # Can't reduce more than the # tokens in src
91
+ r = min(a.shape[1], r)
92
+
93
+ # Find the most similar greedily
94
+ node_max, node_idx = scores.max(dim=-1)
95
+ edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
96
+
97
+ unm_idx = edge_idx[..., r:, :] # Unmerged Tokens
98
+ src_idx = edge_idx[..., :r, :] # Merged Tokens
99
+ dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)
100
+
101
+ def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
102
+ src, dst = split(x)
103
+ n, t1, c = src.shape
104
+
105
+ unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
106
+ src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
107
+ dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
108
+
109
+ return torch.cat([unm, dst], dim=1)
110
+
111
+ def unmerge(x: torch.Tensor) -> torch.Tensor:
112
+ unm_len = unm_idx.shape[1]
113
+ unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
114
+ _, _, c = unm.shape
115
+
116
+ src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))
117
+
118
+ # Combine back to the original shape
119
+ out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
120
+ out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
121
+ out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
122
+ out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)
123
+
124
+ return out
125
+
126
+ return merge, unmerge
127
+
128
+
129
+ def get_functions(x, ratio, original_shape):
130
+ b, c, original_h, original_w = original_shape
131
+ original_tokens = original_h * original_w
132
+ downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
133
+ stride_x = 2
134
+ stride_y = 2
135
+ max_downsample = 1
136
+
137
+ if downsample <= max_downsample:
138
+ w = int(math.ceil(original_w / downsample))
139
+ h = int(math.ceil(original_h / downsample))
140
+ r = int(x.shape[1] * ratio)
141
+ no_rand = False
142
+ m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand)
143
+ return m, u
144
+
145
+ nothing = lambda y: y
146
+ return nothing, nothing
147
+
148
+
149
+
150
+ class TomePatchModel:
151
+ @classmethod
152
+ def INPUT_TYPES(s):
153
+ return {"required": { "model": ("MODEL",),
154
+ "ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}),
155
+ }}
156
+ RETURN_TYPES = ("MODEL",)
157
+ FUNCTION = "patch"
158
+
159
+ CATEGORY = "_for_testing"
160
+
161
+ def patch(self, model, ratio):
162
+ self.u = None
163
+ def tomesd_m(q, k, v, extra_options):
164
+ #NOTE: In the reference code get_functions takes x (input of the transformer block) as the argument instead of q
165
+ #however from my basic testing it seems that using q instead gives better results
166
+ m, self.u = get_functions(q, ratio, extra_options["original_shape"])
167
+ return m(q), k, v
168
+ def tomesd_u(n, extra_options):
169
+ return self.u(n)
170
+
171
+ m = model.clone()
172
+ m.set_model_attn1_patch(tomesd_m)
173
+ m.set_model_attn1_output_patch(tomesd_u)
174
+ return (m, )
175
+
176
+
177
+ NODE_CLASS_MAPPINGS = {
178
+ "TomePatchModel": TomePatchModel,
179
+ }
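The core of merge() above is a scatter_reduce over the dst tokens. A tiny standalone sketch of just the mean-merge step, with toy shapes (this is not the full bipartite matching):

import torch

dst = torch.arange(12, dtype=torch.float32).reshape(1, 4, 3)  # 4 dst tokens, 3 channels
src = torch.ones(1, 2, 3)                                     # 2 src tokens to fold in
dst_idx = torch.tensor([[[0], [2]]])                          # src token i merges into dst token dst_idx[i]

merged = dst.scatter_reduce(-2, dst_idx.expand(1, 2, 3), src, reduce="mean")
print(merged[0, 0])  # tensor([0.5000, 1.0000, 1.5000]) -- mean of dst row 0 and the src token
print(merged[0, 2])  # tensor([3.5000, 4.0000, 4.5000]) -- mean of dst row 2 and the src token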
ldm_patched/contrib/external_upscale_model.py ADDED
@@ -0,0 +1,68 @@
1
+ # https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
2
+
3
+ import os
4
+ from ldm_patched.pfn import model_loading
5
+ from ldm_patched.modules import model_management
6
+ import torch
7
+ import ldm_patched.modules.utils
8
+ import ldm_patched.utils.path_utils
9
+
10
+ class UpscaleModelLoader:
11
+ @classmethod
12
+ def INPUT_TYPES(s):
13
+ return {"required": { "model_name": (ldm_patched.utils.path_utils.get_filename_list("upscale_models"), ),
14
+ }}
15
+ RETURN_TYPES = ("UPSCALE_MODEL",)
16
+ FUNCTION = "load_model"
17
+
18
+ CATEGORY = "loaders"
19
+
20
+ def load_model(self, model_name):
21
+ model_path = ldm_patched.utils.path_utils.get_full_path("upscale_models", model_name)
22
+ sd = ldm_patched.modules.utils.load_torch_file(model_path, safe_load=True)
23
+ if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
24
+ sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"module.":""})
25
+ out = model_loading.load_state_dict(sd).eval()
26
+ return (out, )
27
+
28
+
29
+ class ImageUpscaleWithModel:
30
+ @classmethod
31
+ def INPUT_TYPES(s):
32
+ return {"required": { "upscale_model": ("UPSCALE_MODEL",),
33
+ "image": ("IMAGE",),
34
+ }}
35
+ RETURN_TYPES = ("IMAGE",)
36
+ FUNCTION = "upscale"
37
+
38
+ CATEGORY = "image/upscaling"
39
+
40
+ def upscale(self, upscale_model, image):
41
+ device = model_management.get_torch_device()
42
+ upscale_model.to(device)
43
+ in_img = image.movedim(-1,-3).to(device)
44
+ free_memory = model_management.get_free_memory(device)
45
+
46
+ tile = 512
47
+ overlap = 32
48
+
49
+ oom = True
50
+ while oom:
51
+ try:
52
+ steps = in_img.shape[0] * ldm_patched.modules.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
53
+ pbar = ldm_patched.modules.utils.ProgressBar(steps)
54
+ s = ldm_patched.modules.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
55
+ oom = False
56
+ except model_management.OOM_EXCEPTION as e:
57
+ tile //= 2
58
+ if tile < 128:
59
+ raise e
60
+
61
+ upscale_model.cpu()
62
+ s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
63
+ return (s,)
64
+
65
+ NODE_CLASS_MAPPINGS = {
66
+ "UpscaleModelLoader": UpscaleModelLoader,
67
+ "ImageUpscaleWithModel": ImageUpscaleWithModel
68
+ }
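The while-loop in upscale() above is a generic "shrink the tile on OOM and retry" pattern. A minimal sketch of the same control flow (MemoryError stands in for model_management.OOM_EXCEPTION; function names are illustrative):

def run_tiled(process, tile=512, min_tile=128):
    # Try the largest tile first; halve it whenever the attempt runs out of memory.
    while True:
        try:
            return process(tile)
        except MemoryError:
            tile //= 2
            if tile < min_tile:
                raise

def fake_process(tile):
    # Pretend only tiles of 256 or smaller fit in memory.
    if tile > 256:
        raise MemoryError(f"tile {tile} too large")
    return f"upscaled with tile={tile}"

print(run_tiled(fake_process))  # upscaled with tile=256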
ldm_patched/contrib/external_video_model.py ADDED
@@ -0,0 +1,108 @@
1
+ # https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py
2
+
3
+ import ldm_patched.contrib.external
4
+ import torch
5
+ import ldm_patched.modules.utils
6
+ import ldm_patched.modules.sd
7
+ import ldm_patched.utils.path_utils
8
+ import ldm_patched.contrib.external_model_merging
9
+
10
+
11
+ class ImageOnlyCheckpointLoader:
12
+ @classmethod
13
+ def INPUT_TYPES(s):
14
+ return {"required": { "ckpt_name": (ldm_patched.utils.path_utils.get_filename_list("checkpoints"), ),
15
+ }}
16
+ RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
17
+ FUNCTION = "load_checkpoint"
18
+
19
+ CATEGORY = "loaders/video_models"
20
+
21
+ def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
22
+ ckpt_path = ldm_patched.utils.path_utils.get_full_path("checkpoints", ckpt_name)
23
+ out = ldm_patched.modules.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=ldm_patched.utils.path_utils.get_folder_paths("embeddings"))
24
+ return (out[0], out[3], out[2])
25
+
26
+
27
+ class SVD_img2vid_Conditioning:
28
+ @classmethod
29
+ def INPUT_TYPES(s):
30
+ return {"required": { "clip_vision": ("CLIP_VISION",),
31
+ "init_image": ("IMAGE",),
32
+ "vae": ("VAE",),
33
+ "width": ("INT", {"default": 1024, "min": 16, "max": ldm_patched.contrib.external.MAX_RESOLUTION, "step": 8}),
34
+ "height": ("INT", {"default": 576, "min": 16, "max": ldm_patched.contrib.external.MAX_RESOLUTION, "step": 8}),
35
+ "video_frames": ("INT", {"default": 14, "min": 1, "max": 4096}),
36
+ "motion_bucket_id": ("INT", {"default": 127, "min": 1, "max": 1023}),
37
+ "fps": ("INT", {"default": 6, "min": 1, "max": 1024}),
38
+ "augmentation_level": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01})
39
+ }}
40
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
41
+ RETURN_NAMES = ("positive", "negative", "latent")
42
+
43
+ FUNCTION = "encode"
44
+
45
+ CATEGORY = "conditioning/video_models"
46
+
47
+ def encode(self, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level):
48
+ output = clip_vision.encode_image(init_image)
49
+ pooled = output.image_embeds.unsqueeze(0)
50
+ pixels = ldm_patched.modules.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
51
+ encode_pixels = pixels[:,:,:,:3]
52
+ if augmentation_level > 0:
53
+ encode_pixels += torch.randn_like(pixels) * augmentation_level
54
+ t = vae.encode(encode_pixels)
55
+ positive = [[pooled, {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": t}]]
56
+ negative = [[torch.zeros_like(pooled), {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": torch.zeros_like(t)}]]
57
+ latent = torch.zeros([video_frames, 4, height // 8, width // 8])
58
+ return (positive, negative, {"samples":latent})
59
+
60
+ class VideoLinearCFGGuidance:
61
+ @classmethod
62
+ def INPUT_TYPES(s):
63
+ return {"required": { "model": ("MODEL",),
64
+ "min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
65
+ }}
66
+ RETURN_TYPES = ("MODEL",)
67
+ FUNCTION = "patch"
68
+
69
+ CATEGORY = "sampling/video_models"
70
+
71
+ def patch(self, model, min_cfg):
72
+ def linear_cfg(args):
73
+ cond = args["cond"]
74
+ uncond = args["uncond"]
75
+ cond_scale = args["cond_scale"]
76
+
77
+ scale = torch.linspace(min_cfg, cond_scale, cond.shape[0], device=cond.device).reshape((cond.shape[0], 1, 1, 1))
78
+ return uncond + scale * (cond - uncond)
79
+
80
+ m = model.clone()
81
+ m.set_model_sampler_cfg_function(linear_cfg)
82
+ return (m, )
83
+
84
+ class ImageOnlyCheckpointSave(ldm_patched.contrib.external_model_merging.CheckpointSave):
85
+ CATEGORY = "_for_testing"
86
+
87
+ @classmethod
88
+ def INPUT_TYPES(s):
89
+ return {"required": { "model": ("MODEL",),
90
+ "clip_vision": ("CLIP_VISION",),
91
+ "vae": ("VAE",),
92
+ "filename_prefix": ("STRING", {"default": "checkpoints/ldm_patched"}),},
93
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
94
+
95
+ def save(self, model, clip_vision, vae, filename_prefix, prompt=None, extra_pnginfo=None):
96
+ ldm_patched.contrib.external_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
97
+ return {}
98
+
99
+ NODE_CLASS_MAPPINGS = {
100
+ "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
101
+ "SVD_img2vid_Conditioning": SVD_img2vid_Conditioning,
102
+ "VideoLinearCFGGuidance": VideoLinearCFGGuidance,
103
+ "ImageOnlyCheckpointSave": ImageOnlyCheckpointSave,
104
+ }
105
+
106
+ NODE_DISPLAY_NAME_MAPPINGS = {
107
+ "ImageOnlyCheckpointLoader": "Image Only Checkpoint Loader (img2vid model)",
108
+ }
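VideoLinearCFGGuidance above interpolates the guidance scale across the batch (i.e. across video frames) instead of applying one CFG value everywhere. A minimal sketch with dummy tensors:

import torch

frames, min_cfg, cond_scale = 4, 1.0, 3.0
cond = torch.randn(frames, 4, 8, 8)
uncond = torch.randn(frames, 4, 8, 8)

# Per-frame guidance scales ramp linearly from min_cfg to cond_scale.
scale = torch.linspace(min_cfg, cond_scale, frames, device=cond.device).reshape(frames, 1, 1, 1)
out = uncond + scale * (cond - uncond)
print(scale.flatten())  # tensor([1.0000, 1.6667, 2.3333, 3.0000])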
ldm_patched/controlnet/cldm.py ADDED
@@ -0,0 +1,312 @@
1
+ #taken from: https://github.com/lllyasviel/ControlNet
2
+ #and modified
3
+
4
+ import torch
5
+ import torch as th
6
+ import torch.nn as nn
7
+
8
+ from ldm_patched.ldm.modules.diffusionmodules.util import (
9
+ zero_module,
10
+ timestep_embedding,
11
+ )
12
+
13
+ from ldm_patched.ldm.modules.attention import SpatialTransformer
14
+ from ldm_patched.ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
15
+ from ldm_patched.ldm.util import exists
16
+ import ldm_patched.modules.ops
17
+
18
+ class ControlledUnetModel(UNetModel):
19
+ #implemented in the ldm unet
20
+ pass
21
+
22
+ class ControlNet(nn.Module):
23
+ def __init__(
24
+ self,
25
+ image_size,
26
+ in_channels,
27
+ model_channels,
28
+ hint_channels,
29
+ num_res_blocks,
30
+ dropout=0,
31
+ channel_mult=(1, 2, 4, 8),
32
+ conv_resample=True,
33
+ dims=2,
34
+ num_classes=None,
35
+ use_checkpoint=False,
36
+ dtype=torch.float32,
37
+ num_heads=-1,
38
+ num_head_channels=-1,
39
+ num_heads_upsample=-1,
40
+ use_scale_shift_norm=False,
41
+ resblock_updown=False,
42
+ use_new_attention_order=False,
43
+ use_spatial_transformer=False, # custom transformer support
44
+ transformer_depth=1, # custom transformer support
45
+ context_dim=None, # custom transformer support
46
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
47
+ legacy=True,
48
+ disable_self_attentions=None,
49
+ num_attention_blocks=None,
50
+ disable_middle_self_attn=False,
51
+ use_linear_in_transformer=False,
52
+ adm_in_channels=None,
53
+ transformer_depth_middle=None,
54
+ transformer_depth_output=None,
55
+ device=None,
56
+ operations=ldm_patched.modules.ops.disable_weight_init,
57
+ **kwargs,
58
+ ):
59
+ super().__init__()
60
+ assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
61
+ if use_spatial_transformer:
62
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
63
+
64
+ if context_dim is not None:
65
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
66
+ # from omegaconf.listconfig import ListConfig
67
+ # if type(context_dim) == ListConfig:
68
+ # context_dim = list(context_dim)
69
+
70
+ if num_heads_upsample == -1:
71
+ num_heads_upsample = num_heads
72
+
73
+ if num_heads == -1:
74
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
75
+
76
+ if num_head_channels == -1:
77
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
78
+
79
+ self.dims = dims
80
+ self.image_size = image_size
81
+ self.in_channels = in_channels
82
+ self.model_channels = model_channels
83
+
84
+ if isinstance(num_res_blocks, int):
85
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
86
+ else:
87
+ if len(num_res_blocks) != len(channel_mult):
88
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
89
+ "as a list/tuple (per-level) with the same length as channel_mult")
90
+ self.num_res_blocks = num_res_blocks
91
+
92
+ if disable_self_attentions is not None:
93
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
94
+ assert len(disable_self_attentions) == len(channel_mult)
95
+ if num_attention_blocks is not None:
96
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
97
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
98
+
99
+ transformer_depth = transformer_depth[:]
100
+
101
+ self.dropout = dropout
102
+ self.channel_mult = channel_mult
103
+ self.conv_resample = conv_resample
104
+ self.num_classes = num_classes
105
+ self.use_checkpoint = use_checkpoint
106
+ self.dtype = dtype
107
+ self.num_heads = num_heads
108
+ self.num_head_channels = num_head_channels
109
+ self.num_heads_upsample = num_heads_upsample
110
+ self.predict_codebook_ids = n_embed is not None
111
+
112
+ time_embed_dim = model_channels * 4
113
+ self.time_embed = nn.Sequential(
114
+ operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
115
+ nn.SiLU(),
116
+ operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
117
+ )
118
+
119
+ if self.num_classes is not None:
120
+ if isinstance(self.num_classes, int):
121
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
122
+ elif self.num_classes == "continuous":
123
+ print("setting up linear c_adm embedding layer")
124
+ self.label_emb = nn.Linear(1, time_embed_dim)
125
+ elif self.num_classes == "sequential":
126
+ assert adm_in_channels is not None
127
+ self.label_emb = nn.Sequential(
128
+ nn.Sequential(
129
+ operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
130
+ nn.SiLU(),
131
+ operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
132
+ )
133
+ )
134
+ else:
135
+ raise ValueError()
136
+
137
+ self.input_blocks = nn.ModuleList(
138
+ [
139
+ TimestepEmbedSequential(
140
+ operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
141
+ )
142
+ ]
143
+ )
144
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])
145
+
146
+ self.input_hint_block = TimestepEmbedSequential(
147
+ operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
148
+ nn.SiLU(),
149
+ operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
150
+ nn.SiLU(),
151
+ operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
152
+ nn.SiLU(),
153
+ operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
154
+ nn.SiLU(),
155
+ operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
156
+ nn.SiLU(),
157
+ operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
158
+ nn.SiLU(),
159
+ operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
160
+ nn.SiLU(),
161
+ operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
162
+ )
163
+
164
+ self._feature_size = model_channels
165
+ input_block_chans = [model_channels]
166
+ ch = model_channels
167
+ ds = 1
168
+ for level, mult in enumerate(channel_mult):
169
+ for nr in range(self.num_res_blocks[level]):
170
+ layers = [
171
+ ResBlock(
172
+ ch,
173
+ time_embed_dim,
174
+ dropout,
175
+ out_channels=mult * model_channels,
176
+ dims=dims,
177
+ use_checkpoint=use_checkpoint,
178
+ use_scale_shift_norm=use_scale_shift_norm,
179
+ dtype=self.dtype,
180
+ device=device,
181
+ operations=operations,
182
+ )
183
+ ]
184
+ ch = mult * model_channels
185
+ num_transformers = transformer_depth.pop(0)
186
+ if num_transformers > 0:
187
+ if num_head_channels == -1:
188
+ dim_head = ch // num_heads
189
+ else:
190
+ num_heads = ch // num_head_channels
191
+ dim_head = num_head_channels
192
+ if legacy:
193
+ #num_heads = 1
194
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
195
+ if exists(disable_self_attentions):
196
+ disabled_sa = disable_self_attentions[level]
197
+ else:
198
+ disabled_sa = False
199
+
200
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
201
+ layers.append(
202
+ SpatialTransformer(
203
+ ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
204
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
205
+ use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
206
+ )
207
+ )
208
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
209
+ self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
210
+ self._feature_size += ch
211
+ input_block_chans.append(ch)
212
+ if level != len(channel_mult) - 1:
213
+ out_ch = ch
214
+ self.input_blocks.append(
215
+ TimestepEmbedSequential(
216
+ ResBlock(
217
+ ch,
218
+ time_embed_dim,
219
+ dropout,
220
+ out_channels=out_ch,
221
+ dims=dims,
222
+ use_checkpoint=use_checkpoint,
223
+ use_scale_shift_norm=use_scale_shift_norm,
224
+ down=True,
225
+ dtype=self.dtype,
226
+ device=device,
227
+ operations=operations
228
+ )
229
+ if resblock_updown
230
+ else Downsample(
231
+ ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
232
+ )
233
+ )
234
+ )
235
+ ch = out_ch
236
+ input_block_chans.append(ch)
237
+ self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
238
+ ds *= 2
239
+ self._feature_size += ch
240
+
241
+ if num_head_channels == -1:
242
+ dim_head = ch // num_heads
243
+ else:
244
+ num_heads = ch // num_head_channels
245
+ dim_head = num_head_channels
246
+ if legacy:
247
+ #num_heads = 1
248
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
249
+ mid_block = [
250
+ ResBlock(
251
+ ch,
252
+ time_embed_dim,
253
+ dropout,
254
+ dims=dims,
255
+ use_checkpoint=use_checkpoint,
256
+ use_scale_shift_norm=use_scale_shift_norm,
257
+ dtype=self.dtype,
258
+ device=device,
259
+ operations=operations
260
+ )]
261
+ if transformer_depth_middle >= 0:
262
+ mid_block += [SpatialTransformer( # always uses a self-attn
263
+ ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
264
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
265
+ use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
266
+ ),
267
+ ResBlock(
268
+ ch,
269
+ time_embed_dim,
270
+ dropout,
271
+ dims=dims,
272
+ use_checkpoint=use_checkpoint,
273
+ use_scale_shift_norm=use_scale_shift_norm,
274
+ dtype=self.dtype,
275
+ device=device,
276
+ operations=operations
277
+ )]
278
+ self.middle_block = TimestepEmbedSequential(*mid_block)
279
+ self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
280
+ self._feature_size += ch
281
+
282
+ def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
283
+ return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
284
+
285
+ def forward(self, x, hint, timesteps, context, y=None, **kwargs):
286
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
287
+ emb = self.time_embed(t_emb)
288
+
289
+ guided_hint = self.input_hint_block(hint, emb, context)
290
+
291
+ outs = []
292
+
293
+ hs = []
294
+ if self.num_classes is not None:
295
+ assert y.shape[0] == x.shape[0]
296
+ emb = emb + self.label_emb(y)
297
+
298
+ h = x
299
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
300
+ if guided_hint is not None:
301
+ h = module(h, emb, context)
302
+ h += guided_hint
303
+ guided_hint = None
304
+ else:
305
+ h = module(h, emb, context)
306
+ outs.append(zero_conv(h, emb, context))
307
+
308
+ h = self.middle_block(h, emb, context)
309
+ outs.append(self.middle_block_out(h, emb, context))
310
+
311
+ return outs
312
+
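make_zero_conv() above wraps a 1x1 convolution; in the original ControlNet recipe these convolutions are zero-initialized so the control branch has no effect at the start of training (here the weights ultimately come from the loaded checkpoint). A plain-PyTorch sketch of that initialization idea, with an assumed channel count:

import torch
import torch.nn as nn

zero_conv = nn.Conv2d(320, 320, kernel_size=1, padding=0)
nn.init.zeros_(zero_conv.weight)
nn.init.zeros_(zero_conv.bias)

h = torch.randn(1, 320, 32, 32)
print(zero_conv(h).abs().max())  # 0 -- the branch contributes nothing until its weights change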
ldm_patched/k_diffusion/sampling.py ADDED
@@ -0,0 +1,908 @@
1
+ import math
2
+
3
+ from scipy import integrate
4
+ import torch
5
+ from torch import nn
6
+ import torchsde
7
+ from tqdm.auto import trange, tqdm
8
+
9
+ from . import utils
10
+
11
+
12
+ def append_zero(x):
13
+ return torch.cat([x, x.new_zeros([1])])
14
+
15
+
16
+ def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
17
+ """Constructs the noise schedule of Karras et al. (2022)."""
18
+ ramp = torch.linspace(0, 1, n, device=device)
19
+ min_inv_rho = sigma_min ** (1 / rho)
20
+ max_inv_rho = sigma_max ** (1 / rho)
21
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
22
+ return append_zero(sigmas).to(device)
23
+
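The Karras schedule above interpolates between sigma_max and sigma_min in sigma**(1/rho) space, which concentrates steps at low noise levels. A short worked example with illustrative values (not taken from the upload):

import torch

n, sigma_min, sigma_max, rho = 5, 0.1, 10.0, 7.0
ramp = torch.linspace(0, 1, n)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
print(sigmas)  # roughly tensor([10.00, 4.07, 1.45, 0.43, 0.10]) -- spacing tightens toward low sigma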
24
+
25
+ def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
26
+ """Constructs an exponential noise schedule."""
27
+ sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp()
28
+ return append_zero(sigmas)
29
+
30
+
31
+ def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
32
+ """Constructs a polynomial in log sigma noise schedule."""
33
+ ramp = torch.linspace(1, 0, n, device=device) ** rho
34
+ sigmas = torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min))
35
+ return append_zero(sigmas)
36
+
37
+
38
+ def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
39
+ """Constructs a continuous VP noise schedule."""
40
+ t = torch.linspace(1, eps_s, n, device=device)
41
+ sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1)
42
+ return append_zero(sigmas)
43
+
44
+
45
+ def to_d(x, sigma, denoised):
46
+ """Converts a denoiser output to a Karras ODE derivative."""
47
+ return (x - denoised) / utils.append_dims(sigma, x.ndim)
48
+
49
+
50
+ def get_ancestral_step(sigma_from, sigma_to, eta=1.):
51
+ """Calculates the noise level (sigma_down) to step down to and the amount
52
+ of noise to add (sigma_up) when doing an ancestral sampling step."""
53
+ if not eta:
54
+ return sigma_to, 0.
55
+ sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
56
+ sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
57
+ return sigma_down, sigma_up
58
+
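get_ancestral_step() above splits the target noise level into a deterministic part (sigma_down) and fresh noise (sigma_up) such that sigma_down**2 + sigma_up**2 equals sigma_to**2. A quick numeric check with illustrative values:

sigma_from, sigma_to, eta = 2.0, 1.0, 1.0
sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
print(sigma_down, sigma_up)                      # 0.5 0.866...
print((sigma_down ** 2 + sigma_up ** 2) ** 0.5)  # 1.0 -- recombines to sigma_to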
59
+
60
+ def default_noise_sampler(x):
61
+ return lambda sigma, sigma_next: torch.randn_like(x)
62
+
63
+
64
+ class BatchedBrownianTree:
65
+ """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
66
+
67
+ def __init__(self, x, t0, t1, seed=None, **kwargs):
68
+ self.cpu_tree = True
69
+ if "cpu" in kwargs:
70
+ self.cpu_tree = kwargs.pop("cpu")
71
+ t0, t1, self.sign = self.sort(t0, t1)
72
+ w0 = kwargs.get('w0', torch.zeros_like(x))
73
+ if seed is None:
74
+ seed = torch.randint(0, 2 ** 63 - 1, []).item()
75
+ self.batched = True
76
+ try:
77
+ assert len(seed) == x.shape[0]
78
+ w0 = w0[0]
79
+ except TypeError:
80
+ seed = [seed]
81
+ self.batched = False
82
+ if self.cpu_tree:
83
+ self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed]
84
+ else:
85
+ self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
86
+
87
+ @staticmethod
88
+ def sort(a, b):
89
+ return (a, b, 1) if a < b else (b, a, -1)
90
+
91
+ def __call__(self, t0, t1):
92
+ t0, t1, sign = self.sort(t0, t1)
93
+ if self.cpu_tree:
94
+ w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign)
95
+ else:
96
+ w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
97
+
98
+ return w if self.batched else w[0]
99
+
100
+
101
+ class BrownianTreeNoiseSampler:
102
+ """A noise sampler backed by a torchsde.BrownianTree.
103
+
104
+ Args:
105
+ x (Tensor): The tensor whose shape, device and dtype to use to generate
106
+ random samples.
107
+ sigma_min (float): The low end of the valid interval.
108
+ sigma_max (float): The high end of the valid interval.
109
+ seed (int or List[int]): The random seed. If a list of seeds is
110
+ supplied instead of a single integer, then the noise sampler will
111
+ use one BrownianTree per batch item, each with its own seed.
112
+ transform (callable): A function that maps sigma to the sampler's
113
+ internal timestep.
114
+ """
115
+
116
+ def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False):
117
+ self.transform = transform
118
+ t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
119
+ self.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu)
120
+
121
+ def __call__(self, sigma, sigma_next):
122
+ t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
123
+ return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
124
+
125
+
126
+ @torch.no_grad()
127
+ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
128
+ """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
129
+ extra_args = {} if extra_args is None else extra_args
130
+ s_in = x.new_ones([x.shape[0]])
131
+ for i in trange(len(sigmas) - 1, disable=disable):
132
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
133
+ sigma_hat = sigmas[i] * (gamma + 1)
134
+ if gamma > 0:
135
+ eps = torch.randn_like(x) * s_noise
136
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
137
+ denoised = model(x, sigma_hat * s_in, **extra_args)
138
+ d = to_d(x, sigma_hat, denoised)
139
+ if callback is not None:
140
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
141
+ dt = sigmas[i + 1] - sigma_hat
142
+ # Euler method
143
+ x = x + d * dt
144
+ return x
145
+
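As a usage sketch for the Euler sampler above (assuming this upload is importable as a package under ldm_patched; the toy denoiser and numbers are made up): a model that always predicts an all-zero clean latent should be driven to zero along the schedule.

import torch
from ldm_patched.k_diffusion import sampling

def toy_model(x, sigma, **extra):
    # Pretend the denoiser always predicts a perfectly clean all-zero latent.
    return torch.zeros_like(x)

x = torch.randn(1, 4, 8, 8) * 10.0
sigmas = sampling.get_sigmas_karras(n=10, sigma_min=0.1, sigma_max=10.0)
out = sampling.sample_euler(toy_model, x, sigmas, disable=True)
print(out.abs().max())  # tensor(0.) -- the sampler converges to the toy model's prediction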
146
+
147
+ @torch.no_grad()
148
+ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
149
+ """Ancestral sampling with Euler method steps."""
150
+ extra_args = {} if extra_args is None else extra_args
151
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
152
+ s_in = x.new_ones([x.shape[0]])
153
+ for i in trange(len(sigmas) - 1, disable=disable):
154
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
155
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
156
+ if callback is not None:
157
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
158
+ d = to_d(x, sigmas[i], denoised)
159
+ # Euler method
160
+ dt = sigma_down - sigmas[i]
161
+ x = x + d * dt
162
+ if sigmas[i + 1] > 0:
163
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
164
+ return x
165
+
166
+
167
+ @torch.no_grad()
168
+ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
169
+ """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
170
+ extra_args = {} if extra_args is None else extra_args
171
+ s_in = x.new_ones([x.shape[0]])
172
+ for i in trange(len(sigmas) - 1, disable=disable):
173
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
174
+ sigma_hat = sigmas[i] * (gamma + 1)
175
+ if gamma > 0:
176
+ eps = torch.randn_like(x) * s_noise
177
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
178
+ denoised = model(x, sigma_hat * s_in, **extra_args)
179
+ d = to_d(x, sigma_hat, denoised)
180
+ if callback is not None:
181
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
182
+ dt = sigmas[i + 1] - sigma_hat
183
+ if sigmas[i + 1] == 0:
184
+ # Euler method
185
+ x = x + d * dt
186
+ else:
187
+ # Heun's method
188
+ x_2 = x + d * dt
189
+ denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
190
+ d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
191
+ d_prime = (d + d_2) / 2
192
+ x = x + d_prime * dt
193
+ return x
194
+
195
+
196
+ @torch.no_grad()
197
+ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
198
+ """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
199
+ extra_args = {} if extra_args is None else extra_args
200
+ s_in = x.new_ones([x.shape[0]])
201
+ for i in trange(len(sigmas) - 1, disable=disable):
202
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
203
+ sigma_hat = sigmas[i] * (gamma + 1)
204
+ if gamma > 0:
205
+ eps = torch.randn_like(x) * s_noise
206
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
207
+ denoised = model(x, sigma_hat * s_in, **extra_args)
208
+ d = to_d(x, sigma_hat, denoised)
209
+ if callback is not None:
210
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
211
+ if sigmas[i + 1] == 0:
212
+ # Euler method
213
+ dt = sigmas[i + 1] - sigma_hat
214
+ x = x + d * dt
215
+ else:
216
+ # DPM-Solver-2
217
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
218
+ dt_1 = sigma_mid - sigma_hat
219
+ dt_2 = sigmas[i + 1] - sigma_hat
220
+ x_2 = x + d * dt_1
221
+ denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
222
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
223
+ x = x + d_2 * dt_2
224
+ return x
225
+
226
+
227
+ @torch.no_grad()
228
+ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
229
+ """Ancestral sampling with DPM-Solver second-order steps."""
230
+ extra_args = {} if extra_args is None else extra_args
231
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
232
+ s_in = x.new_ones([x.shape[0]])
233
+ for i in trange(len(sigmas) - 1, disable=disable):
234
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
235
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
236
+ if callback is not None:
237
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
238
+ d = to_d(x, sigmas[i], denoised)
239
+ if sigma_down == 0:
240
+ # Euler method
241
+ dt = sigma_down - sigmas[i]
242
+ x = x + d * dt
243
+ else:
244
+ # DPM-Solver-2
245
+ sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
246
+ dt_1 = sigma_mid - sigmas[i]
247
+ dt_2 = sigma_down - sigmas[i]
248
+ x_2 = x + d * dt_1
249
+ denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
250
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
251
+ x = x + d_2 * dt_2
252
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
253
+ return x
254
+
255
+
256
+ def linear_multistep_coeff(order, t, i, j):
257
+ if order - 1 > i:
258
+ raise ValueError(f'Order {order} too high for step {i}')
259
+ def fn(tau):
260
+ prod = 1.
261
+ for k in range(order):
262
+ if j == k:
263
+ continue
264
+ prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
265
+ return prod
266
+ return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0]
267
+
268
+
269
+ @torch.no_grad()
270
+ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
271
+ extra_args = {} if extra_args is None else extra_args
272
+ s_in = x.new_ones([x.shape[0]])
273
+ sigmas_cpu = sigmas.detach().cpu().numpy()
274
+ ds = []
275
+ for i in trange(len(sigmas) - 1, disable=disable):
276
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
277
+ d = to_d(x, sigmas[i], denoised)
278
+ ds.append(d)
279
+ if len(ds) > order:
280
+ ds.pop(0)
281
+ if callback is not None:
282
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
283
+ cur_order = min(i + 1, order)
284
+ coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
285
+ x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
286
+ return x
287
+
288
+
289
+ class PIDStepSizeController:
290
+ """A PID controller for ODE adaptive step size control."""
291
+ def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
292
+ self.h = h
293
+ self.b1 = (pcoeff + icoeff + dcoeff) / order
294
+ self.b2 = -(pcoeff + 2 * dcoeff) / order
295
+ self.b3 = dcoeff / order
296
+ self.accept_safety = accept_safety
297
+ self.eps = eps
298
+ self.errs = []
299
+
300
+ def limiter(self, x):
301
+ return 1 + math.atan(x - 1)
302
+
303
+ def propose_step(self, error):
304
+ inv_error = 1 / (float(error) + self.eps)
305
+ if not self.errs:
306
+ self.errs = [inv_error, inv_error, inv_error]
307
+ self.errs[0] = inv_error
308
+ factor = self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3
309
+ factor = self.limiter(factor)
310
+ accept = factor >= self.accept_safety
311
+ if accept:
312
+ self.errs[2] = self.errs[1]
313
+ self.errs[1] = self.errs[0]
314
+ self.h *= factor
315
+ return accept
316
+
317
+
318
+ class DPMSolver(nn.Module):
319
+ """DPM-Solver. See https://arxiv.org/abs/2206.00927."""
320
+
321
+ def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None):
322
+ super().__init__()
323
+ self.model = model
324
+ self.extra_args = {} if extra_args is None else extra_args
325
+ self.eps_callback = eps_callback
326
+ self.info_callback = info_callback
327
+
328
+ def t(self, sigma):
329
+ return -sigma.log()
330
+
331
+ def sigma(self, t):
332
+ return t.neg().exp()
333
+
334
+ def eps(self, eps_cache, key, x, t, *args, **kwargs):
335
+ if key in eps_cache:
336
+ return eps_cache[key], eps_cache
337
+ sigma = self.sigma(t) * x.new_ones([x.shape[0]])
338
+ eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t)
339
+ if self.eps_callback is not None:
340
+ self.eps_callback()
341
+ return eps, {key: eps, **eps_cache}
342
+
343
+ def dpm_solver_1_step(self, x, t, t_next, eps_cache=None):
344
+ eps_cache = {} if eps_cache is None else eps_cache
345
+ h = t_next - t
346
+ eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
347
+ x_1 = x - self.sigma(t_next) * h.expm1() * eps
348
+ return x_1, eps_cache
349
+
350
+ def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None):
351
+ eps_cache = {} if eps_cache is None else eps_cache
352
+ h = t_next - t
353
+ eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
354
+ s1 = t + r1 * h
355
+ u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
356
+ eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
357
+ x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps)
358
+ return x_2, eps_cache
359
+
360
+ def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None):
361
+ eps_cache = {} if eps_cache is None else eps_cache
362
+ h = t_next - t
363
+ eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
364
+ s1 = t + r1 * h
365
+ s2 = t + r2 * h
366
+ u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
367
+ eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
368
+ u2 = x - self.sigma(s2) * (r2 * h).expm1() * eps - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps)
369
+ eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2)
370
+ x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps)
371
+ return x_3, eps_cache
372
+
373
+ def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
374
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
375
+ if not t_end > t_start and eta:
376
+ raise ValueError('eta must be 0 for reverse sampling')
377
+
378
+ m = math.floor(nfe / 3) + 1
379
+ ts = torch.linspace(t_start, t_end, m + 1, device=x.device)
380
+
381
+ if nfe % 3 == 0:
382
+ orders = [3] * (m - 2) + [2, 1]
383
+ else:
384
+ orders = [3] * (m - 1) + [nfe % 3]
385
+
386
+ for i in range(len(orders)):
387
+ eps_cache = {}
388
+ t, t_next = ts[i], ts[i + 1]
389
+ if eta:
390
+ sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
391
+ t_next_ = torch.minimum(t_end, self.t(sd))
392
+ su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5
393
+ else:
394
+ t_next_, su = t_next, 0.
395
+
396
+ eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
397
+ denoised = x - self.sigma(t) * eps
398
+ if self.info_callback is not None:
399
+ self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised})
400
+
401
+ if orders[i] == 1:
402
+ x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache)
403
+ elif orders[i] == 2:
404
+ x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache)
405
+ else:
406
+ x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache)
407
+
408
+ x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next))
409
+
410
+ return x
411
+
412
+ def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
413
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
414
+ if order not in {2, 3}:
415
+ raise ValueError('order should be 2 or 3')
416
+ forward = t_end > t_start
417
+ if not forward and eta:
418
+ raise ValueError('eta must be 0 for reverse sampling')
419
+ h_init = abs(h_init) * (1 if forward else -1)
420
+ atol = torch.tensor(atol)
421
+ rtol = torch.tensor(rtol)
422
+ s = t_start
423
+ x_prev = x
424
+ accept = True
425
+ pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety)
426
+ info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0}
427
+
428
+ while s < t_end - 1e-5 if forward else s > t_end + 1e-5:
429
+ eps_cache = {}
430
+ t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h)
431
+ if eta:
432
+ sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta)
433
+ t_ = torch.minimum(t_end, self.t(sd))
434
+ su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5
435
+ else:
436
+ t_, su = t, 0.
437
+
438
+ eps, eps_cache = self.eps(eps_cache, 'eps', x, s)
439
+ denoised = x - self.sigma(s) * eps
440
+
441
+ if order == 2:
442
+ x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache)
443
+ x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache)
444
+ else:
445
+ x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache)
446
+ x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache)
447
+ delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs()))
448
+ error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5
449
+ accept = pid.propose_step(error)
450
+ if accept:
451
+ x_prev = x_low
452
+ x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t))
453
+ s = t
454
+ info['n_accept'] += 1
455
+ else:
456
+ info['n_reject'] += 1
457
+ info['nfe'] += order
458
+ info['steps'] += 1
459
+
460
+ if self.info_callback is not None:
461
+ self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info})
462
+
463
+ return x, info
464
+
465
+
466
+ @torch.no_grad()
467
+ def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None):
468
+ """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927."""
469
+ if sigma_min <= 0 or sigma_max <= 0:
470
+ raise ValueError('sigma_min and sigma_max must not be 0')
471
+ with tqdm(total=n, disable=disable) as pbar:
472
+ dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
473
+ if callback is not None:
474
+ dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
475
+ return dpm_solver.dpm_solver_fast(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), n, eta, s_noise, noise_sampler)
476
+
477
+
478
+ @torch.no_grad()
479
+ def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False):
480
+ """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927."""
481
+ if sigma_min <= 0 or sigma_max <= 0:
482
+ raise ValueError('sigma_min and sigma_max must not be 0')
483
+ with tqdm(disable=disable) as pbar:
484
+ dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
485
+ if callback is not None:
486
+ dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
487
+ x, info = dpm_solver.dpm_solver_adaptive(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise, noise_sampler)
488
+ if return_info:
489
+ return x, info
490
+ return x
491
+
492
+
493
+ @torch.no_grad()
494
+ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
495
+ """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
496
+ extra_args = {} if extra_args is None else extra_args
497
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
498
+ s_in = x.new_ones([x.shape[0]])
499
+ sigma_fn = lambda t: t.neg().exp()
500
+ t_fn = lambda sigma: sigma.log().neg()
501
+
502
+ for i in trange(len(sigmas) - 1, disable=disable):
503
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
504
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
505
+ if callback is not None:
506
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
507
+ if sigma_down == 0:
508
+ # Euler method
509
+ d = to_d(x, sigmas[i], denoised)
510
+ dt = sigma_down - sigmas[i]
511
+ x = x + d * dt
512
+ else:
513
+ # DPM-Solver++(2S)
514
+ t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
515
+ r = 1 / 2
516
+ h = t_next - t
517
+ s = t + r * h
518
+ x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * denoised
519
+ denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
520
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2
521
+ # Noise addition
522
+ if sigmas[i + 1] > 0:
523
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
524
+ return x
525
+
526
+
527
+ @torch.no_grad()
528
+ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
529
+ """DPM-Solver++ (stochastic)."""
530
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
531
+ seed = extra_args.get("seed", None)
532
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
533
+ extra_args = {} if extra_args is None else extra_args
534
+ s_in = x.new_ones([x.shape[0]])
535
+ sigma_fn = lambda t: t.neg().exp()
536
+ t_fn = lambda sigma: sigma.log().neg()
537
+
538
+ for i in trange(len(sigmas) - 1, disable=disable):
539
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
540
+ if callback is not None:
541
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
542
+ if sigmas[i + 1] == 0:
543
+ # Euler method
544
+ d = to_d(x, sigmas[i], denoised)
545
+ dt = sigmas[i + 1] - sigmas[i]
546
+ x = x + d * dt
547
+ else:
548
+ # DPM-Solver++
549
+ t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
550
+ h = t_next - t
551
+ s = t + h * r
552
+ fac = 1 / (2 * r)
553
+
554
+ # Step 1
555
+ sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
556
+ s_ = t_fn(sd)
557
+ x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
558
+ x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
559
+ denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
560
+
561
+ # Step 2
562
+ sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
563
+ t_next_ = t_fn(sd)
564
+ denoised_d = (1 - fac) * denoised + fac * denoised_2
565
+ x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
566
+ x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
567
+ return x
568
+
569
+
570
+ @torch.no_grad()
571
+ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
572
+ """DPM-Solver++(2M)."""
573
+ extra_args = {} if extra_args is None else extra_args
574
+ s_in = x.new_ones([x.shape[0]])
575
+ sigma_fn = lambda t: t.neg().exp()
576
+ t_fn = lambda sigma: sigma.log().neg()
577
+ old_denoised = None
578
+
579
+ for i in trange(len(sigmas) - 1, disable=disable):
580
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
581
+ if callback is not None:
582
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
583
+ t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
584
+ h = t_next - t
585
+ if old_denoised is None or sigmas[i + 1] == 0:
586
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
587
+ else:
588
+ h_last = t - t_fn(sigmas[i - 1])
589
+ r = h_last / h
590
+ denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
591
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
592
+ old_denoised = denoised
593
+ return x
594
+
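For context, a minimal sketch of driving the 2M sampler end to end; get_sigmas_karras is defined earlier in this module, while the denoiser and the sigma range are placeholder assumptions.

    sigmas = get_sigmas_karras(n=20, sigma_min=0.03, sigma_max=14.6)   # assumed SD-like sigma range
    x = torch.randn(1, 4, 64, 64) * sigmas[0]     # latent-shaped noise scaled to the first sigma
    x0 = sample_dpmpp_2m(denoiser, x, sigmas)     # denoiser(x, sigma, **extra_args) is assumed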
595
+ @torch.no_grad()
596
+ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
597
+ """DPM-Solver++(2M) SDE."""
598
+
599
+ if solver_type not in {'heun', 'midpoint'}:
600
+ raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
601
+
602
+ extra_args = {} if extra_args is None else extra_args  # normalize extra_args before reading the seed
+ seed = extra_args.get("seed", None)
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
606
+ s_in = x.new_ones([x.shape[0]])
607
+
608
+ old_denoised = None
609
+ h_last = None
610
+ h = None
611
+
612
+ for i in trange(len(sigmas) - 1, disable=disable):
613
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
614
+ if callback is not None:
615
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
616
+ if sigmas[i + 1] == 0:
617
+ # Denoising step
618
+ x = denoised
619
+ else:
620
+ # DPM-Solver++(2M) SDE
621
+ t, s = -sigmas[i].log(), -sigmas[i + 1].log()
622
+ h = s - t
623
+ eta_h = eta * h
624
+
625
+ x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
626
+
627
+ if old_denoised is not None:
628
+ r = h_last / h
629
+ if solver_type == 'heun':
630
+ x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
631
+ elif solver_type == 'midpoint':
632
+ x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
633
+
634
+ if eta:
635
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
636
+
637
+ old_denoised = denoised
638
+ h_last = h
639
+ return x
640
+
641
+ @torch.no_grad()
642
+ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
643
+ """DPM-Solver++(3M) SDE."""
644
+
645
+ extra_args = {} if extra_args is None else extra_args  # normalize extra_args before reading the seed
+ seed = extra_args.get("seed", None)
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
649
+ s_in = x.new_ones([x.shape[0]])
650
+
651
+ denoised_1, denoised_2 = None, None
652
+ h, h_1, h_2 = None, None, None
653
+
654
+ for i in trange(len(sigmas) - 1, disable=disable):
655
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
656
+ if callback is not None:
657
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
658
+ if sigmas[i + 1] == 0:
659
+ # Denoising step
660
+ x = denoised
661
+ else:
662
+ t, s = -sigmas[i].log(), -sigmas[i + 1].log()
663
+ h = s - t
664
+ h_eta = h * (eta + 1)
665
+
666
+ x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised
667
+
668
+ if h_2 is not None:
669
+ r0 = h_1 / h
670
+ r1 = h_2 / h
671
+ d1_0 = (denoised - denoised_1) / r0
672
+ d1_1 = (denoised_1 - denoised_2) / r1
673
+ d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1)
674
+ d2 = (d1_0 - d1_1) / (r0 + r1)
675
+ phi_2 = h_eta.neg().expm1() / h_eta + 1
676
+ phi_3 = phi_2 / h_eta - 0.5
677
+ x = x + phi_2 * d1 - phi_3 * d2
678
+ elif h_1 is not None:
679
+ r = h_1 / h
680
+ d = (denoised - denoised_1) / r
681
+ phi_2 = h_eta.neg().expm1() / h_eta + 1
682
+ x = x + phi_2 * d
683
+
684
+ if eta:
685
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
686
+
687
+ denoised_1, denoised_2 = denoised, denoised_1
688
+ h_1, h_2 = h, h_1
689
+ return x
690
+
691
+ @torch.no_grad()
692
+ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
693
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
694
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=(extra_args or {}).get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
695
+ return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
696
+
697
+ @torch.no_grad()
698
+ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
699
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
700
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=(extra_args or {}).get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
701
+ return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
702
+
703
+ @torch.no_grad()
704
+ def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
705
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
706
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=(extra_args or {}).get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
707
+ return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
708
+
709
+
710
+ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
711
+ alpha_cumprod = 1 / ((sigma * sigma) + 1)
712
+ alpha_cumprod_prev = 1 / ((sigma_prev * sigma_prev) + 1)
713
+ alpha = (alpha_cumprod / alpha_cumprod_prev)
714
+
715
+ mu = (1.0 / alpha).sqrt() * (x - (1 - alpha) * noise / (1 - alpha_cumprod).sqrt())
716
+ if sigma_prev > 0:
717
+ mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev)
718
+ return mu
719
+
720
+ def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
721
+ extra_args = {} if extra_args is None else extra_args
722
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
723
+ s_in = x.new_ones([x.shape[0]])
724
+
725
+ for i in trange(len(sigmas) - 1, disable=disable):
726
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
727
+ if callback is not None:
728
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
729
+ x = step_function(x / torch.sqrt(1.0 + sigmas[i] ** 2.0), sigmas[i], sigmas[i + 1], (x - denoised) / sigmas[i], noise_sampler)
730
+ if sigmas[i + 1] != 0:
731
+ x *= torch.sqrt(1.0 + sigmas[i + 1] ** 2.0)
732
+ return x
733
+
734
+
735
+ @torch.no_grad()
736
+ def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
737
+ return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step)
738
+
739
+ @torch.no_grad()
740
+ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
741
+ extra_args = {} if extra_args is None else extra_args
742
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
743
+ s_in = x.new_ones([x.shape[0]])
744
+ for i in trange(len(sigmas) - 1, disable=disable):
745
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
746
+ if callback is not None:
747
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
748
+
749
+ x = denoised
750
+ if sigmas[i + 1] > 0:
751
+ x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
752
+ return x
753
+
754
+
755
+ @torch.no_grad()
756
+ def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
757
+ # From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/
758
+ extra_args = {} if extra_args is None else extra_args
759
+ s_in = x.new_ones([x.shape[0]])
760
+ s_end = sigmas[-1]
761
+ for i in trange(len(sigmas) - 1, disable=disable):
762
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
763
+ eps = torch.randn_like(x) * s_noise
764
+ sigma_hat = sigmas[i] * (gamma + 1)
765
+ if gamma > 0:
766
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
767
+ denoised = model(x, sigma_hat * s_in, **extra_args)
768
+ d = to_d(x, sigma_hat, denoised)
769
+ if callback is not None:
770
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
771
+ dt = sigmas[i + 1] - sigma_hat
772
+ if sigmas[i + 1] == s_end:
773
+ # Euler method
774
+ x = x + d * dt
775
+ elif sigmas[i + 2] == s_end:
776
+
777
+ # Heun's method
778
+ x_2 = x + d * dt
779
+ denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
780
+ d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
781
+
782
+ w = 2 * sigmas[0]
783
+ w2 = sigmas[i+1]/w
784
+ w1 = 1 - w2
785
+
786
+ d_prime = d * w1 + d_2 * w2
787
+
788
+
789
+ x = x + d_prime * dt
790
+
791
+ else:
792
+ # Heun++
793
+ x_2 = x + d * dt
794
+ denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
795
+ d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
796
+ dt_2 = sigmas[i + 2] - sigmas[i + 1]
797
+
798
+ x_3 = x_2 + d_2 * dt_2
799
+ denoised_3 = model(x_3, sigmas[i + 2] * s_in, **extra_args)
800
+ d_3 = to_d(x_3, sigmas[i + 2], denoised_3)
801
+
802
+ w = 3 * sigmas[0]
803
+ w2 = sigmas[i + 1] / w
804
+ w3 = sigmas[i + 2] / w
805
+ w1 = 1 - w2 - w3
806
+
807
+ d_prime = w1 * d + w2 * d_2 + w3 * d_3
808
+ x = x + d_prime * dt
809
+ return x
810
+
811
+
812
+ @torch.no_grad()
813
+ def sample_tcd(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, eta=0.3):
814
+ extra_args = {} if extra_args is None else extra_args
815
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
816
+ s_in = x.new_ones([x.shape[0]])
817
+
818
+ model_sampling = model.inner_model.inner_model.model_sampling
819
+ timesteps_s = torch.floor((1 - eta) * model_sampling.timestep(sigmas)).to(dtype=torch.long).detach().cpu()
820
+ timesteps_s[-1] = 0
821
+ alpha_prod_s = model_sampling.alphas_cumprod[timesteps_s]
822
+ beta_prod_s = 1 - alpha_prod_s
823
+ for i in trange(len(sigmas) - 1, disable=disable):
824
+ denoised = model(x, sigmas[i] * s_in, **extra_args) # predicted_original_sample
825
+ eps = (x - denoised) / sigmas[i]
826
+ denoised = alpha_prod_s[i + 1].sqrt() * denoised + beta_prod_s[i + 1].sqrt() * eps
827
+
828
+ if callback is not None:
829
+ callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
830
+
831
+ x = denoised
832
+ if eta > 0 and sigmas[i + 1] > 0:
833
+ noise = noise_sampler(sigmas[i], sigmas[i + 1])
834
+ x = x / alpha_prod_s[i+1].sqrt() + noise * (sigmas[i+1]**2 + 1 - 1/alpha_prod_s[i+1]).sqrt()
835
+ else:
836
+ x *= torch.sqrt(1.0 + sigmas[i + 1] ** 2)
837
+
838
+ return x
839
+
840
+
841
+ @torch.no_grad()
842
+ def sample_restart(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., restart_list=None):
843
+ """Implements restart sampling in Restart Sampling for Improving Generative Processes (2023)
844
+ Restart_list format: {min_sigma: [ restart_steps, restart_times, max_sigma]}
845
+ If restart_list is None: will choose restart_list automatically, otherwise will use the given restart_list
846
+ """
847
+ extra_args = {} if extra_args is None else extra_args
848
+ s_in = x.new_ones([x.shape[0]])
849
+ step_id = 0
850
+
851
+ def heun_step(x, old_sigma, new_sigma, second_order=True):
852
+ nonlocal step_id
853
+ denoised = model(x, old_sigma * s_in, **extra_args)
854
+ d = to_d(x, old_sigma, denoised)
855
+ if callback is not None:
856
+ callback({'x': x, 'i': step_id, 'sigma': new_sigma, 'sigma_hat': old_sigma, 'denoised': denoised})
857
+ dt = new_sigma - old_sigma
858
+ if new_sigma == 0 or not second_order:
859
+ # Euler method
860
+ x = x + d * dt
861
+ else:
862
+ # Heun's method
863
+ x_2 = x + d * dt
864
+ denoised_2 = model(x_2, new_sigma * s_in, **extra_args)
865
+ d_2 = to_d(x_2, new_sigma, denoised_2)
866
+ d_prime = (d + d_2) / 2
867
+ x = x + d_prime * dt
868
+ step_id += 1
869
+ return x
870
+
871
+ steps = sigmas.shape[0] - 1
872
+ if restart_list is None:
873
+ if steps >= 20:
874
+ restart_steps = 9
875
+ restart_times = 1
876
+ if steps >= 36:
877
+ restart_steps = steps // 4
878
+ restart_times = 2
879
+ sigmas = get_sigmas_karras(steps - restart_steps * restart_times, sigmas[-2].item(), sigmas[0].item(), device=sigmas.device)
880
+ restart_list = {0.1: [restart_steps + 1, restart_times, 2]}
881
+ else:
882
+ restart_list = {}
883
+
884
+ restart_list = {int(torch.argmin(abs(sigmas - key), dim=0)): value for key, value in restart_list.items()}
885
+
886
+ step_list = []
887
+ for i in range(len(sigmas) - 1):
888
+ step_list.append((sigmas[i], sigmas[i + 1]))
889
+ if i + 1 in restart_list:
890
+ restart_steps, restart_times, restart_max = restart_list[i + 1]
891
+ min_idx = i + 1
892
+ max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0))
893
+ if max_idx < min_idx:
894
+ sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1]
895
+ while restart_times > 0:
896
+ restart_times -= 1
897
+ step_list.extend(zip(sigma_restart[:-1], sigma_restart[1:]))
898
+
899
+ last_sigma = None
900
+ for old_sigma, new_sigma in tqdm(step_list, disable=disable):
901
+ if last_sigma is None:
902
+ last_sigma = old_sigma
903
+ elif last_sigma < old_sigma:
904
+ x = x + torch.randn_like(x) * s_noise * (old_sigma ** 2 - last_sigma ** 2) ** 0.5
905
+ x = heun_step(x, old_sigma, new_sigma)
906
+ last_sigma = new_sigma
907
+
908
+ return x
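A short sketch of passing an explicit restart_list (values are illustrative, not tuned): whenever the schedule crosses sigma of roughly 0.1, run two restart cycles over a short 3-point Karras segment, re-noising up to sigma = 2.0.

    restart_list = {0.1: [3, 2, 2.0]}   # {min_sigma: [restart_steps, restart_times, max_sigma]}
    x0 = sample_restart(denoiser, x, sigmas, restart_list=restart_list)   # denoiser, x, sigmas as in the sketches above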
ldm_patched/k_diffusion/utils.py ADDED
@@ -0,0 +1,313 @@
1
+ from contextlib import contextmanager
2
+ import hashlib
3
+ import math
4
+ from pathlib import Path
5
+ import shutil
6
+ import urllib
7
+ import warnings
8
+
9
+ from PIL import Image
10
+ import torch
11
+ from torch import nn, optim
12
+ from torch.utils import data
13
+
14
+
15
+ def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
16
+ """Apply passed in transforms for HuggingFace Datasets."""
17
+ images = [transform(image.convert(mode)) for image in examples[image_key]]
18
+ return {image_key: images}
19
+
20
+
21
+ def append_dims(x, target_dims):
22
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
23
+ dims_to_append = target_dims - x.ndim
24
+ if dims_to_append < 0:
25
+ raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
26
+ expanded = x[(...,) + (None,) * dims_to_append]
27
+ # MPS will get inf values if it tries to index into the new axes, but detaching fixes this.
28
+ # https://github.com/pytorch/pytorch/issues/84364
29
+ return expanded.detach().clone() if expanded.device.type == 'mps' else expanded
30
+
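For example, broadcasting a per-sample sigma against an NCHW batch:

    sigma = torch.tensor([1.0, 2.0])   # shape (2,)
    sigma4 = append_dims(sigma, 4)     # shape (2, 1, 1, 1), broadcasts against a (2, C, H, W) tensor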
31
+
32
+ def n_params(module):
33
+ """Returns the number of trainable parameters in a module."""
34
+ return sum(p.numel() for p in module.parameters())
35
+
36
+
37
+ def download_file(path, url, digest=None):
38
+ """Downloads a file if it does not exist, optionally checking its SHA-256 hash."""
39
+ path = Path(path)
40
+ path.parent.mkdir(parents=True, exist_ok=True)
41
+ if not path.exists():
42
+ with urllib.request.urlopen(url) as response, open(path, 'wb') as f:
43
+ shutil.copyfileobj(response, f)
44
+ if digest is not None:
45
+ file_digest = hashlib.sha256(open(path, 'rb').read()).hexdigest()
46
+ if digest != file_digest:
47
+ raise OSError(f'hash of {path} (url: {url}) failed to validate')
48
+ return path
49
+
50
+
51
+ @contextmanager
52
+ def train_mode(model, mode=True):
53
+ """A context manager that places a model into training mode and restores
54
+ the previous mode on exit."""
55
+ modes = [module.training for module in model.modules()]
56
+ try:
57
+ yield model.train(mode)
58
+ finally:
59
+ for i, module in enumerate(model.modules()):
60
+ module.training = modes[i]
61
+
62
+
63
+ def eval_mode(model):
64
+ """A context manager that places a model into evaluation mode and restores
65
+ the previous mode on exit."""
66
+ return train_mode(model, False)
67
+
68
+
69
+ @torch.no_grad()
70
+ def ema_update(model, averaged_model, decay):
71
+ """Incorporates updated model parameters into an exponential moving averaged
72
+ version of a model. It should be called after each optimizer step."""
73
+ model_params = dict(model.named_parameters())
74
+ averaged_params = dict(averaged_model.named_parameters())
75
+ assert model_params.keys() == averaged_params.keys()
76
+
77
+ for name, param in model_params.items():
78
+ averaged_params[name].mul_(decay).add_(param, alpha=1 - decay)
79
+
80
+ model_buffers = dict(model.named_buffers())
81
+ averaged_buffers = dict(averaged_model.named_buffers())
82
+ assert model_buffers.keys() == averaged_buffers.keys()
83
+
84
+ for name, buf in model_buffers.items():
85
+ averaged_buffers[name].copy_(buf)
86
+
87
+
88
+ class EMAWarmup:
89
+ """Implements an EMA warmup using an inverse decay schedule.
90
+ If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
91
+ good values for models you plan to train for a million or more steps (reaches decay
92
+ factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
93
+ you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
94
+ 215.4k steps).
95
+ Args:
96
+ inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
97
+ power (float): Exponential factor of EMA warmup. Default: 1.
98
+ min_value (float): The minimum EMA decay rate. Default: 0.
99
+ max_value (float): The maximum EMA decay rate. Default: 1.
100
+ start_at (int): The epoch to start averaging at. Default: 0.
101
+ last_epoch (int): The index of last epoch. Default: 0.
102
+ """
103
+
104
+ def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0,
105
+ last_epoch=0):
106
+ self.inv_gamma = inv_gamma
107
+ self.power = power
108
+ self.min_value = min_value
109
+ self.max_value = max_value
110
+ self.start_at = start_at
111
+ self.last_epoch = last_epoch
112
+
113
+ def state_dict(self):
114
+ """Returns the state of the class as a :class:`dict`."""
115
+ return dict(self.__dict__.items())
116
+
117
+ def load_state_dict(self, state_dict):
118
+ """Loads the class's state.
119
+ Args:
120
+ state_dict (dict): scaler state. Should be an object returned
121
+ from a call to :meth:`state_dict`.
122
+ """
123
+ self.__dict__.update(state_dict)
124
+
125
+ def get_value(self):
126
+ """Gets the current EMA decay rate."""
127
+ epoch = max(0, self.last_epoch - self.start_at)
128
+ value = 1 - (1 + epoch / self.inv_gamma) ** -self.power
129
+ return 0. if epoch < 0 else min(self.max_value, max(self.min_value, value))
130
+
131
+ def step(self):
132
+ """Updates the step count."""
133
+ self.last_epoch += 1
134
+
135
+
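A minimal training-loop sketch combining EMAWarmup with ema_update above; model, averaged_model, opt, num_steps and compute_loss are placeholders, not part of this module.

    ema_sched = EMAWarmup(power=2 / 3)   # reaches ~0.999 decay around 31.6K steps, per the docstring
    for _ in range(num_steps):
        loss = compute_loss(model)       # placeholder objective
        opt.zero_grad()
        loss.backward()
        opt.step()
        ema_update(model, averaged_model, ema_sched.get_value())
        ema_sched.step()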
136
+ class InverseLR(optim.lr_scheduler._LRScheduler):
137
+ """Implements an inverse decay learning rate schedule with an optional exponential
138
+ warmup. When last_epoch=-1, sets initial lr as lr.
139
+ inv_gamma is the number of steps/epochs required for the learning rate to decay to
140
+ (1 / 2)**power of its original value.
141
+ Args:
142
+ optimizer (Optimizer): Wrapped optimizer.
143
+ inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1.
144
+ power (float): Exponential factor of learning rate decay. Default: 1.
145
+ warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
146
+ Default: 0.
147
+ min_lr (float): The minimum learning rate. Default: 0.
148
+ last_epoch (int): The index of last epoch. Default: -1.
149
+ verbose (bool): If ``True``, prints a message to stdout for
150
+ each update. Default: ``False``.
151
+ """
152
+
153
+ def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0.,
154
+ last_epoch=-1, verbose=False):
155
+ self.inv_gamma = inv_gamma
156
+ self.power = power
157
+ if not 0. <= warmup < 1:
158
+ raise ValueError('Invalid value for warmup')
159
+ self.warmup = warmup
160
+ self.min_lr = min_lr
161
+ super().__init__(optimizer, last_epoch, verbose)
162
+
163
+ def get_lr(self):
164
+ if not self._get_lr_called_within_step:
165
+ warnings.warn("To get the last learning rate computed by the scheduler, "
166
+ "please use `get_last_lr()`.")
167
+
168
+ return self._get_closed_form_lr()
169
+
170
+ def _get_closed_form_lr(self):
171
+ warmup = 1 - self.warmup ** (self.last_epoch + 1)
172
+ lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power
173
+ return [warmup * max(self.min_lr, base_lr * lr_mult)
174
+ for base_lr in self.base_lrs]
175
+
176
+
177
+ class ExponentialLR(optim.lr_scheduler._LRScheduler):
178
+ """Implements an exponential learning rate schedule with an optional exponential
179
+ warmup. When last_epoch=-1, sets initial lr as lr. Decays the learning rate
180
+ continuously by decay (default 0.5) every num_steps steps.
181
+ Args:
182
+ optimizer (Optimizer): Wrapped optimizer.
183
+ num_steps (float): The number of steps to decay the learning rate by decay in.
184
+ decay (float): The factor by which to decay the learning rate every num_steps
185
+ steps. Default: 0.5.
186
+ warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
187
+ Default: 0.
188
+ min_lr (float): The minimum learning rate. Default: 0.
189
+ last_epoch (int): The index of last epoch. Default: -1.
190
+ verbose (bool): If ``True``, prints a message to stdout for
191
+ each update. Default: ``False``.
192
+ """
193
+
194
+ def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0.,
195
+ last_epoch=-1, verbose=False):
196
+ self.num_steps = num_steps
197
+ self.decay = decay
198
+ if not 0. <= warmup < 1:
199
+ raise ValueError('Invalid value for warmup')
200
+ self.warmup = warmup
201
+ self.min_lr = min_lr
202
+ super().__init__(optimizer, last_epoch, verbose)
203
+
204
+ def get_lr(self):
205
+ if not self._get_lr_called_within_step:
206
+ warnings.warn("To get the last learning rate computed by the scheduler, "
207
+ "please use `get_last_lr()`.")
208
+
209
+ return self._get_closed_form_lr()
210
+
211
+ def _get_closed_form_lr(self):
212
+ warmup = 1 - self.warmup ** (self.last_epoch + 1)
213
+ lr_mult = (self.decay ** (1 / self.num_steps)) ** self.last_epoch
214
+ return [warmup * max(self.min_lr, base_lr * lr_mult)
215
+ for base_lr in self.base_lrs]
216
+
217
+
218
+ def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
219
+ """Draws samples from a lognormal distribution."""
220
+ return (torch.randn(shape, device=device, dtype=dtype) * scale + loc).exp()
221
+
222
+
223
+ def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
224
+ """Draws samples from an optionally truncated log-logistic distribution."""
225
+ min_value = torch.as_tensor(min_value, device=device, dtype=torch.float64)
226
+ max_value = torch.as_tensor(max_value, device=device, dtype=torch.float64)
227
+ min_cdf = min_value.log().sub(loc).div(scale).sigmoid()
228
+ max_cdf = max_value.log().sub(loc).div(scale).sigmoid()
229
+ u = torch.rand(shape, device=device, dtype=torch.float64) * (max_cdf - min_cdf) + min_cdf
230
+ return u.logit().mul(scale).add(loc).exp().to(dtype)
231
+
232
+
233
+ def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32):
234
+ """Draws samples from a log-uniform distribution."""
235
+ min_value = math.log(min_value)
236
+ max_value = math.log(max_value)
237
+ return (torch.rand(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp()
238
+
239
+
240
+ def rand_v_diffusion(shape, sigma_data=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
241
+ """Draws samples from a truncated v-diffusion training timestep distribution."""
242
+ min_cdf = math.atan(min_value / sigma_data) * 2 / math.pi
243
+ max_cdf = math.atan(max_value / sigma_data) * 2 / math.pi
244
+ u = torch.rand(shape, device=device, dtype=dtype) * (max_cdf - min_cdf) + min_cdf
245
+ return torch.tan(u * math.pi / 2) * sigma_data
246
+
247
+
248
+ def rand_split_log_normal(shape, loc, scale_1, scale_2, device='cpu', dtype=torch.float32):
249
+ """Draws samples from a split lognormal distribution."""
250
+ n = torch.randn(shape, device=device, dtype=dtype).abs()
251
+ u = torch.rand(shape, device=device, dtype=dtype)
252
+ n_left = n * -scale_1 + loc
253
+ n_right = n * scale_2 + loc
254
+ ratio = scale_1 / (scale_1 + scale_2)
255
+ return torch.where(u < ratio, n_left, n_right).exp()
256
+
257
+
258
+ class FolderOfImages(data.Dataset):
259
+ """Recursively finds all images in a directory. It does not support
260
+ classes/targets."""
261
+
262
+ IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'}
263
+
264
+ def __init__(self, root, transform=None):
265
+ super().__init__()
266
+ self.root = Path(root)
267
+ self.transform = nn.Identity() if transform is None else transform
268
+ self.paths = sorted(path for path in self.root.rglob('*') if path.suffix.lower() in self.IMG_EXTENSIONS)
269
+
270
+ def __repr__(self):
271
+ return f'FolderOfImages(root="{self.root}", len: {len(self)})'
272
+
273
+ def __len__(self):
274
+ return len(self.paths)
275
+
276
+ def __getitem__(self, key):
277
+ path = self.paths[key]
278
+ with open(path, 'rb') as f:
279
+ image = Image.open(f).convert('RGB')
280
+ image = self.transform(image)
281
+ return image,
282
+
283
+
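Usage sketch; the directory path is a placeholder and torchvision is assumed only for the example transform.

    from torchvision import transforms
    ds = FolderOfImages('/path/to/images', transform=transforms.ToTensor())
    loader = data.DataLoader(ds, batch_size=16, shuffle=True, num_workers=2)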
284
+ class CSVLogger:
285
+ def __init__(self, filename, columns):
286
+ self.filename = Path(filename)
287
+ self.columns = columns
288
+ if self.filename.exists():
289
+ self.file = open(self.filename, 'a')
290
+ else:
291
+ self.file = open(self.filename, 'w')
292
+ self.write(*self.columns)
293
+
294
+ def write(self, *args):
295
+ print(*args, sep=',', file=self.file, flush=True)
296
+
297
+
298
+ @contextmanager
299
+ def tf32_mode(cudnn=None, matmul=None):
300
+ """A context manager that sets whether TF32 is allowed on cuDNN or matmul."""
301
+ cudnn_old = torch.backends.cudnn.allow_tf32
302
+ matmul_old = torch.backends.cuda.matmul.allow_tf32
303
+ try:
304
+ if cudnn is not None:
305
+ torch.backends.cudnn.allow_tf32 = cudnn
306
+ if matmul is not None:
307
+ torch.backends.cuda.matmul.allow_tf32 = matmul
308
+ yield
309
+ finally:
310
+ if cudnn is not None:
311
+ torch.backends.cudnn.allow_tf32 = cudnn_old
312
+ if matmul is not None:
313
+ torch.backends.cuda.matmul.allow_tf32 = matmul_old
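Usage sketch for the TF32 context manager; only the flags passed as non-None are changed, and they are restored on exit.

    with tf32_mode(matmul=True):
        a = torch.randn(1024, 1024, device='cuda')   # assumes a CUDA device is available
        b = torch.randn(1024, 1024, device='cuda')
        c = a @ b   # matmuls may use TF32 inside this block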
ldm_patched/ldm/models/autoencoder.py ADDED
@@ -0,0 +1,228 @@
1
+ import logging
+ import math
+ import torch
+ # import pytorch_lightning as pl
+ import torch.nn.functional as F
+ from contextlib import contextmanager
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ from ldm_patched.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+
+ from ldm_patched.ldm.util import instantiate_from_config
+ from ldm_patched.ldm.modules.ema import LitEma
+ import ldm_patched.modules.ops
+
+ # module-level logger used by the EMA info messages below; math is needed by the
+ # batched encode/decode paths when max_batch_size is set
+ logpy = logging.getLogger(__name__)
12
+
13
+ class DiagonalGaussianRegularizer(torch.nn.Module):
14
+ def __init__(self, sample: bool = True):
15
+ super().__init__()
16
+ self.sample = sample
17
+
18
+ def get_trainable_parameters(self) -> Any:
19
+ yield from ()
20
+
21
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
22
+ log = dict()
23
+ posterior = DiagonalGaussianDistribution(z)
24
+ if self.sample:
25
+ z = posterior.sample()
26
+ else:
27
+ z = posterior.mode()
28
+ kl_loss = posterior.kl()
29
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
30
+ log["kl_loss"] = kl_loss
31
+ return z, log
32
+
33
+
34
+ class AbstractAutoencoder(torch.nn.Module):
35
+ """
36
+ This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
37
+ unCLIP models, etc. Hence, it is fairly general, and specific features
38
+ (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ ema_decay: Union[None, float] = None,
44
+ monitor: Union[None, str] = None,
45
+ input_key: str = "jpg",
46
+ **kwargs,
47
+ ):
48
+ super().__init__()
49
+
50
+ self.input_key = input_key
51
+ self.use_ema = ema_decay is not None
52
+ if monitor is not None:
53
+ self.monitor = monitor
54
+
55
+ if self.use_ema:
56
+ self.model_ema = LitEma(self, decay=ema_decay)
57
+ logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
58
+
59
+ def get_input(self, batch) -> Any:
60
+ raise NotImplementedError()
61
+
62
+ def on_train_batch_end(self, *args, **kwargs):
63
+ # for EMA computation
64
+ if self.use_ema:
65
+ self.model_ema(self)
66
+
67
+ @contextmanager
68
+ def ema_scope(self, context=None):
69
+ if self.use_ema:
70
+ self.model_ema.store(self.parameters())
71
+ self.model_ema.copy_to(self)
72
+ if context is not None:
73
+ logpy.info(f"{context}: Switched to EMA weights")
74
+ try:
75
+ yield None
76
+ finally:
77
+ if self.use_ema:
78
+ self.model_ema.restore(self.parameters())
79
+ if context is not None:
80
+ logpy.info(f"{context}: Restored training weights")
81
+
82
+ def encode(self, *args, **kwargs) -> torch.Tensor:
83
+ raise NotImplementedError("encode()-method of abstract base class called")
84
+
85
+ def decode(self, *args, **kwargs) -> torch.Tensor:
86
+ raise NotImplementedError("decode()-method of abstract base class called")
87
+
88
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
89
+ logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
90
+ return get_obj_from_str(cfg["target"])(
91
+ params, lr=lr, **cfg.get("params", dict())
92
+ )
93
+
94
+ def configure_optimizers(self) -> Any:
95
+ raise NotImplementedError()
96
+
97
+
98
+ class AutoencodingEngine(AbstractAutoencoder):
99
+ """
100
+ Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
101
+ (we also restore them explicitly as special cases for legacy reasons).
102
+ Regularizations such as KL or VQ are moved to the regularizer class.
103
+ """
104
+
105
+ def __init__(
106
+ self,
107
+ *args,
108
+ encoder_config: Dict,
109
+ decoder_config: Dict,
110
+ regularizer_config: Dict,
111
+ **kwargs,
112
+ ):
113
+ super().__init__(*args, **kwargs)
114
+
115
+ self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
116
+ self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
117
+ self.regularization: AbstractRegularizer = instantiate_from_config(
118
+ regularizer_config
119
+ )
120
+
121
+ def get_last_layer(self):
122
+ return self.decoder.get_last_layer()
123
+
124
+ def encode(
125
+ self,
126
+ x: torch.Tensor,
127
+ return_reg_log: bool = False,
128
+ unregularized: bool = False,
129
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
130
+ z = self.encoder(x)
131
+ if unregularized:
132
+ return z, dict()
133
+ z, reg_log = self.regularization(z)
134
+ if return_reg_log:
135
+ return z, reg_log
136
+ return z
137
+
138
+ def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
139
+ x = self.decoder(z, **kwargs)
140
+ return x
141
+
142
+ def forward(
143
+ self, x: torch.Tensor, **additional_decode_kwargs
144
+ ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
145
+ z, reg_log = self.encode(x, return_reg_log=True)
146
+ dec = self.decode(z, **additional_decode_kwargs)
147
+ return z, dec, reg_log
148
+
149
+
150
+ class AutoencodingEngineLegacy(AutoencodingEngine):
151
+ def __init__(self, embed_dim: int, **kwargs):
152
+ self.max_batch_size = kwargs.pop("max_batch_size", None)
153
+ ddconfig = kwargs.pop("ddconfig")
154
+ super().__init__(
155
+ encoder_config={
156
+ "target": "ldm_patched.ldm.modules.diffusionmodules.model.Encoder",
157
+ "params": ddconfig,
158
+ },
159
+ decoder_config={
160
+ "target": "ldm_patched.ldm.modules.diffusionmodules.model.Decoder",
161
+ "params": ddconfig,
162
+ },
163
+ **kwargs,
164
+ )
165
+ self.quant_conv = ldm_patched.modules.ops.disable_weight_init.Conv2d(
166
+ (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
167
+ (1 + ddconfig["double_z"]) * embed_dim,
168
+ 1,
169
+ )
170
+ self.post_quant_conv = ldm_patched.modules.ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1)
171
+ self.embed_dim = embed_dim
172
+
173
+ def get_autoencoder_params(self) -> list:
174
+ params = super().get_autoencoder_params()
175
+ return params
176
+
177
+ def encode(
178
+ self, x: torch.Tensor, return_reg_log: bool = False
179
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
180
+ if self.max_batch_size is None:
181
+ z = self.encoder(x)
182
+ z = self.quant_conv(z)
183
+ else:
184
+ N = x.shape[0]
185
+ bs = self.max_batch_size
186
+ n_batches = int(math.ceil(N / bs))
187
+ z = list()
188
+ for i_batch in range(n_batches):
189
+ z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
190
+ z_batch = self.quant_conv(z_batch)
191
+ z.append(z_batch)
192
+ z = torch.cat(z, 0)
193
+
194
+ z, reg_log = self.regularization(z)
195
+ if return_reg_log:
196
+ return z, reg_log
197
+ return z
198
+
199
+ def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
200
+ if self.max_batch_size is None:
201
+ dec = self.post_quant_conv(z)
202
+ dec = self.decoder(dec, **decoder_kwargs)
203
+ else:
204
+ N = z.shape[0]
205
+ bs = self.max_batch_size
206
+ n_batches = int(math.ceil(N / bs))
207
+ dec = list()
208
+ for i_batch in range(n_batches):
209
+ dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
210
+ dec_batch = self.decoder(dec_batch, **decoder_kwargs)
211
+ dec.append(dec_batch)
212
+ dec = torch.cat(dec, 0)
213
+
214
+ return dec
215
+
216
+
217
+ class AutoencoderKL(AutoencodingEngineLegacy):
218
+ def __init__(self, **kwargs):
219
+ if "lossconfig" in kwargs:
220
+ kwargs["loss_config"] = kwargs.pop("lossconfig")
221
+ super().__init__(
222
+ regularizer_config={
223
+ "target": (
224
+ "ldm_patched.ldm.models.autoencoder.DiagonalGaussianRegularizer"
225
+ )
226
+ },
227
+ **kwargs,
228
+ )
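A rough encode/decode round-trip sketch; the ddconfig values below are illustrative SD-style settings, not taken from a config in this repository.

    ae = AutoencoderKL(embed_dim=4, ddconfig={
        'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3,
        'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2,
        'attn_resolutions': [], 'dropout': 0.0})
    img = torch.randn(1, 3, 256, 256)
    z = ae.encode(img)    # latent sampled from the diagonal Gaussian posterior, (1, 4, 32, 32)
    rec = ae.decode(z)    # reconstruction, (1, 3, 256, 256)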
ldm_patched/ldm/modules/attention.py ADDED
@@ -0,0 +1,781 @@
1
+ import math
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch import nn, einsum
5
+ from einops import rearrange, repeat
6
+ from typing import Optional, Any
7
+
8
+ from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding
9
+ from .sub_quadratic_attention import efficient_dot_product_attention
10
+
11
+ from ldm_patched.modules import model_management
12
+
13
+ if model_management.xformers_enabled():
14
+ import xformers
15
+ import xformers.ops
16
+
17
+ from ldm_patched.modules.args_parser import args
18
+ import ldm_patched.modules.ops
19
+ ops = ldm_patched.modules.ops.disable_weight_init
20
+
21
+ # CrossAttn precision handling
22
+ if args.disable_attention_upcast:
23
+ print("disabling upcasting of attention")
24
+ _ATTN_PRECISION = "fp16"
25
+ else:
26
+ _ATTN_PRECISION = "fp32"
27
+
28
+
29
+ def exists(val):
30
+ return val is not None
31
+
32
+
33
+ def uniq(arr):
34
+ return{el: True for el in arr}.keys()
35
+
36
+
37
+ def default(val, d):
38
+ if exists(val):
39
+ return val
40
+ return d
41
+
42
+
43
+ def max_neg_value(t):
44
+ return -torch.finfo(t.dtype).max
45
+
46
+
47
+ def init_(tensor):
48
+ dim = tensor.shape[-1]
49
+ std = 1 / math.sqrt(dim)
50
+ tensor.uniform_(-std, std)
51
+ return tensor
52
+
53
+
54
+ # feedforward
55
+ class GEGLU(nn.Module):
56
+ def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops):
57
+ super().__init__()
58
+ self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)
59
+
60
+ def forward(self, x):
61
+ x, gate = self.proj(x).chunk(2, dim=-1)
62
+ return x * F.gelu(gate)
63
+
64
+
65
+ class FeedForward(nn.Module):
66
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=ops):
67
+ super().__init__()
68
+ inner_dim = int(dim * mult)
69
+ dim_out = default(dim_out, dim)
70
+ project_in = nn.Sequential(
71
+ operations.Linear(dim, inner_dim, dtype=dtype, device=device),
72
+ nn.GELU()
73
+ ) if not glu else GEGLU(dim, inner_dim, dtype=dtype, device=device, operations=operations)
74
+
75
+ self.net = nn.Sequential(
76
+ project_in,
77
+ nn.Dropout(dropout),
78
+ operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
79
+ )
80
+
81
+ def forward(self, x):
82
+ return self.net(x)
83
+
84
+ def Normalize(in_channels, dtype=None, device=None):
85
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
86
+
87
+ def attention_basic(q, k, v, heads, mask=None):
88
+ b, _, dim_head = q.shape
89
+ dim_head //= heads
90
+ scale = dim_head ** -0.5
91
+
92
+ h = heads
93
+ q, k, v = map(
94
+ lambda t: t.unsqueeze(3)
95
+ .reshape(b, -1, heads, dim_head)
96
+ .permute(0, 2, 1, 3)
97
+ .reshape(b * heads, -1, dim_head)
98
+ .contiguous(),
99
+ (q, k, v),
100
+ )
101
+
102
+ # force cast to fp32 to avoid overflowing
103
+ if _ATTN_PRECISION =="fp32":
104
+ sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
105
+ else:
106
+ sim = einsum('b i d, b j d -> b i j', q, k) * scale
107
+
108
+ del q, k
109
+
110
+ if exists(mask):
111
+ if mask.dtype == torch.bool:
112
+ mask = rearrange(mask, 'b ... -> b (...)') #TODO: check if this bool part matches pytorch attention
113
+ max_neg_value = -torch.finfo(sim.dtype).max
114
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
115
+ sim.masked_fill_(~mask, max_neg_value)
116
+ else:
117
+ sim += mask
118
+
119
+ # attention, what we cannot get enough of
120
+ sim = sim.softmax(dim=-1)
121
+
122
+ out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v)
123
+ out = (
124
+ out.unsqueeze(0)
125
+ .reshape(b, heads, -1, dim_head)
126
+ .permute(0, 2, 1, 3)
127
+ .reshape(b, -1, heads * dim_head)
128
+ )
129
+ return out
130
+
131
+
132
+ def attention_sub_quad(query, key, value, heads, mask=None):
133
+ b, _, dim_head = query.shape
134
+ dim_head //= heads
135
+
136
+ scale = dim_head ** -0.5
137
+ query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
138
+ value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
139
+
140
+ key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
141
+
142
+ dtype = query.dtype
143
+ upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
144
+ if upcast_attention:
145
+ bytes_per_token = torch.finfo(torch.float32).bits//8
146
+ else:
147
+ bytes_per_token = torch.finfo(query.dtype).bits//8
148
+ batch_x_heads, q_tokens, _ = query.shape
149
+ _, _, k_tokens = key.shape
150
+ qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
151
+
152
+ mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
153
+
154
+ kv_chunk_size_min = None
155
+ kv_chunk_size = None
156
+ query_chunk_size = None
157
+
158
+ for x in [4096, 2048, 1024, 512, 256]:
159
+ count = mem_free_total / (batch_x_heads * bytes_per_token * x * 4.0)
160
+ if count >= k_tokens:
161
+ kv_chunk_size = k_tokens
162
+ query_chunk_size = x
163
+ break
164
+
165
+ if query_chunk_size is None:
166
+ query_chunk_size = 512
167
+
168
+ hidden_states = efficient_dot_product_attention(
169
+ query,
170
+ key,
171
+ value,
172
+ query_chunk_size=query_chunk_size,
173
+ kv_chunk_size=kv_chunk_size,
174
+ kv_chunk_size_min=kv_chunk_size_min,
175
+ use_checkpoint=False,
176
+ upcast_attention=upcast_attention,
177
+ mask=mask,
178
+ )
179
+
180
+ hidden_states = hidden_states.to(dtype)
181
+
182
+ hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
183
+ return hidden_states
184
+
185
+ def attention_split(q, k, v, heads, mask=None):
186
+ b, _, dim_head = q.shape
187
+ dim_head //= heads
188
+ scale = dim_head ** -0.5
189
+
190
+ h = heads
191
+ q, k, v = map(
192
+ lambda t: t.unsqueeze(3)
193
+ .reshape(b, -1, heads, dim_head)
194
+ .permute(0, 2, 1, 3)
195
+ .reshape(b * heads, -1, dim_head)
196
+ .contiguous(),
197
+ (q, k, v),
198
+ )
199
+
200
+ r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
201
+
202
+ mem_free_total = model_management.get_free_memory(q.device)
203
+
204
+ if _ATTN_PRECISION =="fp32":
205
+ element_size = 4
206
+ else:
207
+ element_size = q.element_size()
208
+
209
+ gb = 1024 ** 3
210
+ tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
211
+ modifier = 3
212
+ mem_required = tensor_size * modifier
213
+ steps = 1
214
+
215
+
216
+ if mem_required > mem_free_total:
217
+ steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
218
+ # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
219
+ # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
220
+
221
+ if steps > 64:
222
+ max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
223
+ raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
224
+ f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
225
+
226
+ # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size)
227
+ first_op_done = False
228
+ cleared_cache = False
229
+ while True:
230
+ try:
231
+ slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
232
+ for i in range(0, q.shape[1], slice_size):
233
+ end = i + slice_size
234
+ if _ATTN_PRECISION =="fp32":
235
+ with torch.autocast(enabled=False, device_type = 'cuda'):
236
+ s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
237
+ else:
238
+ s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
239
+
240
+ if mask is not None:
241
+ if len(mask.shape) == 2:
242
+ s1 += mask[i:end]
243
+ else:
244
+ s1 += mask[:, i:end]
245
+
246
+ s2 = s1.softmax(dim=-1).to(v.dtype)
247
+ del s1
248
+ first_op_done = True
249
+
250
+ r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
251
+ del s2
252
+ break
253
+ except model_management.OOM_EXCEPTION as e:
254
+ if first_op_done == False:
255
+ model_management.soft_empty_cache(True)
256
+ if cleared_cache == False:
257
+ cleared_cache = True
258
+ print("out of memory error, emptying cache and trying again")
259
+ continue
260
+ steps *= 2
261
+ if steps > 64:
262
+ raise e
263
+ print("out of memory error, increasing steps and trying again", steps)
264
+ else:
265
+ raise e
266
+
267
+ del q, k, v
268
+
269
+ r1 = (
270
+ r1.unsqueeze(0)
271
+ .reshape(b, heads, -1, dim_head)
272
+ .permute(0, 2, 1, 3)
273
+ .reshape(b, -1, heads * dim_head)
274
+ )
275
+ return r1
276
+
277
+ BROKEN_XFORMERS = False
278
+ try:
279
+ x_vers = xformers.__version__
280
+ #I think 0.0.23 is also broken (q with bs bigger than 65535 gives CUDA error)
281
+ BROKEN_XFORMERS = x_vers.startswith("0.0.21") or x_vers.startswith("0.0.22") or x_vers.startswith("0.0.23")
282
+ except:
283
+ pass
284
+
285
+ def attention_xformers(q, k, v, heads, mask=None):
286
+ b, _, dim_head = q.shape
287
+ dim_head //= heads
288
+ if BROKEN_XFORMERS:
289
+ if b * heads > 65535:
290
+ return attention_pytorch(q, k, v, heads, mask)
291
+
292
+ q, k, v = map(
293
+ lambda t: t.unsqueeze(3)
294
+ .reshape(b, -1, heads, dim_head)
295
+ .permute(0, 2, 1, 3)
296
+ .reshape(b * heads, -1, dim_head)
297
+ .contiguous(),
298
+ (q, k, v),
299
+ )
300
+
301
+ if mask is not None:
302
+ pad = 8 - q.shape[1] % 8
303
+ mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
304
+ mask_out[:, :, :mask.shape[-1]] = mask
305
+ mask = mask_out[:, :, :mask.shape[-1]]
306
+
307
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
308
+
309
+ out = (
310
+ out.unsqueeze(0)
311
+ .reshape(b, heads, -1, dim_head)
312
+ .permute(0, 2, 1, 3)
313
+ .reshape(b, -1, heads * dim_head)
314
+ )
315
+ return out
316
+
317
+ def attention_pytorch(q, k, v, heads, mask=None):
318
+ b, _, dim_head = q.shape
319
+ dim_head //= heads
320
+ q, k, v = map(
321
+ lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
322
+ (q, k, v),
323
+ )
324
+
325
+ out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
326
+ out = (
327
+ out.transpose(1, 2).reshape(b, -1, heads * dim_head)
328
+ )
329
+ return out
330
+
331
+
332
+ optimized_attention = attention_basic
333
+
334
+ if model_management.xformers_enabled():
335
+ print("Using xformers cross attention")
336
+ optimized_attention = attention_xformers
337
+ elif model_management.pytorch_attention_enabled():
338
+ print("Using pytorch cross attention")
339
+ optimized_attention = attention_pytorch
340
+ else:
341
+ if args.attention_split:
342
+ print("Using split optimization for cross attention")
343
+ optimized_attention = attention_split
344
+ else:
345
+ print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --attention-split")
346
+ optimized_attention = attention_sub_quad
347
+
348
+ optimized_attention_masked = optimized_attention
349
+
350
+ def optimized_attention_for_device(device, mask=False, small_input=False):
351
+ if small_input:
352
+ if model_management.pytorch_attention_enabled():
353
+ return attention_pytorch #TODO: need to confirm but this is probably slightly faster for small inputs in all cases
354
+ else:
355
+ return attention_basic
356
+
357
+ if device == torch.device("cpu"):
358
+ return attention_sub_quad
359
+
360
+ if mask:
361
+ return optimized_attention_masked
362
+
363
+ return optimized_attention
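All of these kernels share one calling convention, sketched below: q, k and v are (batch, tokens, heads * dim_head) tensors and the head count is passed separately; optimized_attention_for_device simply returns whichever backend was selected above.

    q = k = v = torch.randn(2, 77, 8 * 64)            # batch 2, 77 tokens, 8 heads of 64 dims
    attn = optimized_attention_for_device(q.device)   # picks a backend based on device / availability
    out = attn(q, k, v, heads=8)                      # (2, 77, 512), same layout as the inputs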
364
+
365
+
366
+ class CrossAttention(nn.Module):
367
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
368
+ super().__init__()
369
+ inner_dim = dim_head * heads
370
+ context_dim = default(context_dim, query_dim)
371
+
372
+ self.heads = heads
373
+ self.dim_head = dim_head
374
+
375
+ self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
376
+ self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
377
+ self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
378
+
379
+ self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
380
+
381
+ def forward(self, x, context=None, value=None, mask=None):
382
+ q = self.to_q(x)
383
+ context = default(context, x)
384
+ k = self.to_k(context)
385
+ if value is not None:
386
+ v = self.to_v(value)
387
+ del value
388
+ else:
389
+ v = self.to_v(context)
390
+
391
+ if mask is None:
392
+ out = optimized_attention(q, k, v, self.heads)
393
+ else:
394
+ out = optimized_attention_masked(q, k, v, self.heads, mask)
395
+ return self.to_out(out)
396
+
397
+
398
+ class BasicTransformerBlock(nn.Module):
399
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
400
+ disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=ops):
401
+ super().__init__()
402
+
403
+ self.ff_in = ff_in or inner_dim is not None
404
+ if inner_dim is None:
405
+ inner_dim = dim
406
+
407
+ self.is_res = inner_dim == dim
408
+
409
+ if self.ff_in:
410
+ self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device)
411
+ self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
412
+
413
+ self.disable_self_attn = disable_self_attn
414
+ self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
415
+ context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
416
+ self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
417
+
418
+ if disable_temporal_crossattention:
419
+ if switch_temporal_ca_to_sa:
420
+ raise ValueError
421
+ else:
422
+ self.attn2 = None
423
+ else:
424
+ context_dim_attn2 = None
425
+ if not switch_temporal_ca_to_sa:
426
+ context_dim_attn2 = context_dim
427
+
428
+ self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
429
+ heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
430
+ self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
431
+
432
+ self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
433
+ self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
434
+ self.checkpoint = checkpoint
435
+ self.n_heads = n_heads
436
+ self.d_head = d_head
437
+ self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
438
+
439
+ def forward(self, x, context=None, transformer_options={}):
440
+ return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
441
+
442
+ def _forward(self, x, context=None, transformer_options={}):
443
+ extra_options = {}
444
+ block = transformer_options.get("block", None)
445
+ block_index = transformer_options.get("block_index", 0)
446
+ transformer_patches = {}
447
+ transformer_patches_replace = {}
448
+
449
+ for k in transformer_options:
450
+ if k == "patches":
451
+ transformer_patches = transformer_options[k]
452
+ elif k == "patches_replace":
453
+ transformer_patches_replace = transformer_options[k]
454
+ else:
455
+ extra_options[k] = transformer_options[k]
456
+
457
+ extra_options["n_heads"] = self.n_heads
458
+ extra_options["dim_head"] = self.d_head
459
+
460
+ if self.ff_in:
461
+ x_skip = x
462
+ x = self.ff_in(self.norm_in(x))
463
+ if self.is_res:
464
+ x += x_skip
465
+
466
+ n = self.norm1(x)
467
+ if self.disable_self_attn:
468
+ context_attn1 = context
469
+ else:
470
+ context_attn1 = None
471
+ value_attn1 = None
472
+
473
+ if "attn1_patch" in transformer_patches:
474
+ patch = transformer_patches["attn1_patch"]
475
+ if context_attn1 is None:
476
+ context_attn1 = n
477
+ value_attn1 = context_attn1
478
+ for p in patch:
479
+ n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)
480
+
481
+ if block is not None:
482
+ transformer_block = (block[0], block[1], block_index)
483
+ else:
484
+ transformer_block = None
485
+ attn1_replace_patch = transformer_patches_replace.get("attn1", {})
486
+ block_attn1 = transformer_block
487
+ if block_attn1 not in attn1_replace_patch:
488
+ block_attn1 = block
489
+
490
+ if block_attn1 in attn1_replace_patch:
491
+ if context_attn1 is None:
492
+ context_attn1 = n
493
+ value_attn1 = n
494
+ n = self.attn1.to_q(n)
495
+ context_attn1 = self.attn1.to_k(context_attn1)
496
+ value_attn1 = self.attn1.to_v(value_attn1)
497
+ n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
498
+ n = self.attn1.to_out(n)
499
+ else:
500
+ n = self.attn1(n, context=context_attn1, value=value_attn1)
501
+
502
+ if "attn1_output_patch" in transformer_patches:
503
+ patch = transformer_patches["attn1_output_patch"]
504
+ for p in patch:
505
+ n = p(n, extra_options)
506
+
507
+ x += n
508
+ if "middle_patch" in transformer_patches:
509
+ patch = transformer_patches["middle_patch"]
510
+ for p in patch:
511
+ x = p(x, extra_options)
512
+
513
+ if self.attn2 is not None:
514
+ n = self.norm2(x)
515
+ if self.switch_temporal_ca_to_sa:
516
+ context_attn2 = n
517
+ else:
518
+ context_attn2 = context
519
+ value_attn2 = None
520
+ if "attn2_patch" in transformer_patches:
521
+ patch = transformer_patches["attn2_patch"]
522
+ value_attn2 = context_attn2
523
+ for p in patch:
524
+ n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
525
+
526
+ attn2_replace_patch = transformer_patches_replace.get("attn2", {})
527
+ block_attn2 = transformer_block
528
+ if block_attn2 not in attn2_replace_patch:
529
+ block_attn2 = block
530
+
531
+ if block_attn2 in attn2_replace_patch:
532
+ if value_attn2 is None:
533
+ value_attn2 = context_attn2
534
+ n = self.attn2.to_q(n)
535
+ context_attn2 = self.attn2.to_k(context_attn2)
536
+ value_attn2 = self.attn2.to_v(value_attn2)
537
+ n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
538
+ n = self.attn2.to_out(n)
539
+ else:
540
+ n = self.attn2(n, context=context_attn2, value=value_attn2)
541
+
542
+ if "attn2_output_patch" in transformer_patches:
543
+ patch = transformer_patches["attn2_output_patch"]
544
+ for p in patch:
545
+ n = p(n, extra_options)
546
+
547
+ x += n
548
+ if self.is_res:
549
+ x_skip = x
550
+ x = self.ff(self.norm3(x))
551
+ if self.is_res:
552
+ x += x_skip
553
+
554
+ return x
555
+
556
+
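The `_forward` above consumes `transformer_options` through three channels: per-step "patches" (attn1_patch, attn1_output_patch, middle_patch, attn2_patch, attn2_output_patch), full "patches_replace" overrides keyed by the block tuple, and every other key forwarded as `extra_options`. A minimal caller-side sketch, not part of the uploaded file; the patch function and its scaling factor are illustrative:

# Hypothetical caller-side patch for the hooks read in _forward above (illustrative).
import torch

def damp_self_attention(n: torch.Tensor, extra_options: dict) -> torch.Tensor:
    # An "attn1_output_patch" receives the self-attention output plus extra_options
    # (which carries "n_heads", "dim_head" and any other transformer_options keys)
    # and must return a tensor of the same shape; here it simply rescales it.
    return n * 0.8

transformer_options = {
    "patches": {"attn1_output_patch": [damp_self_attention]},
    # "patches_replace": {"attn1": {...}} would instead swap out the whole
    # attention computation, keyed by the (block name, block id, block_index) tuple.
}
# Passing this dict as transformer_options to the forward above applies the patch
# to the attn1 output right before it is added back onto x.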
557
+ class SpatialTransformer(nn.Module):
558
+ """
559
+ Transformer block for image-like data.
560
+ First, project the input (aka embedding)
561
+ and reshape to b, t, d.
562
+ Then apply standard transformer action.
563
+ Finally, reshape to image
564
+ NEW: use_linear for more efficiency instead of the 1x1 convs
565
+ """
566
+ def __init__(self, in_channels, n_heads, d_head,
567
+ depth=1, dropout=0., context_dim=None,
568
+ disable_self_attn=False, use_linear=False,
569
+ use_checkpoint=True, dtype=None, device=None, operations=ops):
570
+ super().__init__()
571
+ if exists(context_dim) and not isinstance(context_dim, list):
572
+ context_dim = [context_dim] * depth
573
+ self.in_channels = in_channels
574
+ inner_dim = n_heads * d_head
575
+ self.norm = operations.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
576
+ if not use_linear:
577
+ self.proj_in = operations.Conv2d(in_channels,
578
+ inner_dim,
579
+ kernel_size=1,
580
+ stride=1,
581
+ padding=0, dtype=dtype, device=device)
582
+ else:
583
+ self.proj_in = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)
584
+
585
+ self.transformer_blocks = nn.ModuleList(
586
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
587
+ disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device, operations=operations)
588
+ for d in range(depth)]
589
+ )
590
+ if not use_linear:
591
+ self.proj_out = operations.Conv2d(inner_dim,in_channels,
592
+ kernel_size=1,
593
+ stride=1,
594
+ padding=0, dtype=dtype, device=device)
595
+ else:
596
+ self.proj_out = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)
597
+ self.use_linear = use_linear
598
+
599
+ def forward(self, x, context=None, transformer_options={}):
600
+ # note: if no context is given, cross-attention defaults to self-attention
601
+ if not isinstance(context, list):
602
+ context = [context] * len(self.transformer_blocks)
603
+ b, c, h, w = x.shape
604
+ x_in = x
605
+ x = self.norm(x)
606
+ if not self.use_linear:
607
+ x = self.proj_in(x)
608
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
609
+ if self.use_linear:
610
+ x = self.proj_in(x)
611
+ for i, block in enumerate(self.transformer_blocks):
612
+ transformer_options["block_index"] = i
613
+ x = block(x, context=context[i], transformer_options=transformer_options)
614
+ if self.use_linear:
615
+ x = self.proj_out(x)
616
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
617
+ if not self.use_linear:
618
+ x = self.proj_out(x)
619
+ return x + x_in
620
+
621
+
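A rough usage sketch for the class above (not part of the uploaded file): the import path follows this commit's file layout, and the channel/head sizes are illustrative.

# Sketch: run a dummy feature map through SpatialTransformer and check the shape.
# Assumes this repository is on PYTHONPATH; sizes below are illustrative.
import torch
from ldm_patched.ldm.modules.attention import SpatialTransformer

st = SpatialTransformer(in_channels=320, n_heads=8, d_head=40,
                        depth=1, context_dim=768, use_checkpoint=False)
x = torch.randn(2, 320, 32, 32)      # b c h w latent features
ctx = torch.randn(2, 77, 768)        # b tokens dim text conditioning
out = st(x, context=ctx)             # residual is added internally (x + x_in)
assert out.shape == x.shape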
622
+ class SpatialVideoTransformer(SpatialTransformer):
623
+ def __init__(
624
+ self,
625
+ in_channels,
626
+ n_heads,
627
+ d_head,
628
+ depth=1,
629
+ dropout=0.0,
630
+ use_linear=False,
631
+ context_dim=None,
632
+ use_spatial_context=False,
633
+ timesteps=None,
634
+ merge_strategy: str = "fixed",
635
+ merge_factor: float = 0.5,
636
+ time_context_dim=None,
637
+ ff_in=False,
638
+ checkpoint=False,
639
+ time_depth=1,
640
+ disable_self_attn=False,
641
+ disable_temporal_crossattention=False,
642
+ max_time_embed_period: int = 10000,
643
+ dtype=None, device=None, operations=ops
644
+ ):
645
+ super().__init__(
646
+ in_channels,
647
+ n_heads,
648
+ d_head,
649
+ depth=depth,
650
+ dropout=dropout,
651
+ use_checkpoint=checkpoint,
652
+ context_dim=context_dim,
653
+ use_linear=use_linear,
654
+ disable_self_attn=disable_self_attn,
655
+ dtype=dtype, device=device, operations=operations
656
+ )
657
+ self.time_depth = time_depth
658
+ self.depth = depth
659
+ self.max_time_embed_period = max_time_embed_period
660
+
661
+ time_mix_d_head = d_head
662
+ n_time_mix_heads = n_heads
663
+
664
+ time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)
665
+
666
+ inner_dim = n_heads * d_head
667
+ if use_spatial_context:
668
+ time_context_dim = context_dim
669
+
670
+ self.time_stack = nn.ModuleList(
671
+ [
672
+ BasicTransformerBlock(
673
+ inner_dim,
674
+ n_time_mix_heads,
675
+ time_mix_d_head,
676
+ dropout=dropout,
677
+ context_dim=time_context_dim,
678
+ # timesteps=timesteps,
679
+ checkpoint=checkpoint,
680
+ ff_in=ff_in,
681
+ inner_dim=time_mix_inner_dim,
682
+ disable_self_attn=disable_self_attn,
683
+ disable_temporal_crossattention=disable_temporal_crossattention,
684
+ dtype=dtype, device=device, operations=operations
685
+ )
686
+ for _ in range(self.depth)
687
+ ]
688
+ )
689
+
690
+ assert len(self.time_stack) == len(self.transformer_blocks)
691
+
692
+ self.use_spatial_context = use_spatial_context
693
+ self.in_channels = in_channels
694
+
695
+ time_embed_dim = self.in_channels * 4
696
+ self.time_pos_embed = nn.Sequential(
697
+ operations.Linear(self.in_channels, time_embed_dim, dtype=dtype, device=device),
698
+ nn.SiLU(),
699
+ operations.Linear(time_embed_dim, self.in_channels, dtype=dtype, device=device),
700
+ )
701
+
702
+ self.time_mixer = AlphaBlender(
703
+ alpha=merge_factor, merge_strategy=merge_strategy
704
+ )
705
+
706
+ def forward(
707
+ self,
708
+ x: torch.Tensor,
709
+ context: Optional[torch.Tensor] = None,
710
+ time_context: Optional[torch.Tensor] = None,
711
+ timesteps: Optional[int] = None,
712
+ image_only_indicator: Optional[torch.Tensor] = None,
713
+ transformer_options={}
714
+ ) -> torch.Tensor:
715
+ _, _, h, w = x.shape
716
+ x_in = x
717
+ spatial_context = None
718
+ if exists(context):
719
+ spatial_context = context
720
+
721
+ if self.use_spatial_context:
722
+ assert (
723
+ context.ndim == 3
724
+ ), f"n dims of spatial context should be 3 but are {context.ndim}"
725
+
726
+ if time_context is None:
727
+ time_context = context
728
+ time_context_first_timestep = time_context[::timesteps]
729
+ time_context = repeat(
730
+ time_context_first_timestep, "b ... -> (b n) ...", n=h * w
731
+ )
732
+ elif time_context is not None and not self.use_spatial_context:
733
+ time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
734
+ if time_context.ndim == 2:
735
+ time_context = rearrange(time_context, "b c -> b 1 c")
736
+
737
+ x = self.norm(x)
738
+ if not self.use_linear:
739
+ x = self.proj_in(x)
740
+ x = rearrange(x, "b c h w -> b (h w) c")
741
+ if self.use_linear:
742
+ x = self.proj_in(x)
743
+
744
+ num_frames = torch.arange(timesteps, device=x.device)
745
+ num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
746
+ num_frames = rearrange(num_frames, "b t -> (b t)")
747
+ t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False, max_period=self.max_time_embed_period).to(x.dtype)
748
+ emb = self.time_pos_embed(t_emb)
749
+ emb = emb[:, None, :]
750
+
751
+ for it_, (block, mix_block) in enumerate(
752
+ zip(self.transformer_blocks, self.time_stack)
753
+ ):
754
+ transformer_options["block_index"] = it_
755
+ x = block(
756
+ x,
757
+ context=spatial_context,
758
+ transformer_options=transformer_options,
759
+ )
760
+
761
+ x_mix = x
762
+ x_mix = x_mix + emb
763
+
764
+ B, S, C = x_mix.shape
765
+ x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
766
+ x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options
767
+ x_mix = rearrange(
768
+ x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
769
+ )
770
+
771
+ x = self.time_mixer(x_spatial=x, x_temporal=x_mix, image_only_indicator=image_only_indicator)
772
+
773
+ if self.use_linear:
774
+ x = self.proj_out(x)
775
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
776
+ if not self.use_linear:
777
+ x = self.proj_out(x)
778
+ out = x + x_in
779
+ return out
780
+
781
+
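The temporal path above hinges on two einops reshapes: spatial tokens of shape "(b t) s c" are regrouped to "(b s) t c" so the time_stack blocks attend across frames at each spatial location, then mapped back before AlphaBlender merges the spatial and temporal branches. A small standalone sketch of that round trip (not part of the uploaded file; sizes are illustrative):

# Illustrates the "(b t) s c <-> (b s) t c" regrouping used by SpatialVideoTransformer.
import torch
from einops import rearrange

b, t, s, c = 2, 8, 16, 64           # batch, frames, spatial tokens, channels
x = torch.randn(b * t, s, c)        # frames flattened into the batch dimension

x_time = rearrange(x, "(b t) s c -> (b s) t c", t=t)    # attend over t per location
assert x_time.shape == (b * s, t, c)

x_back = rearrange(x_time, "(b s) t c -> (b t) s c", s=s)
assert torch.equal(x_back, x)       # lossless round trip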
ldm_patched/ldm/modules/diffusionmodules/__init__.py ADDED
File without changes
ldm_patched/ldm/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,650 @@
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from einops import rearrange
7
+ from typing import Optional, Any
8
+
9
+ from ldm_patched.modules import model_management
10
+ import ldm_patched.modules.ops
11
+ ops = ldm_patched.modules.ops.disable_weight_init
12
+
13
+ if model_management.xformers_enabled_vae():
14
+ import xformers
15
+ import xformers.ops
16
+
17
+ def get_timestep_embedding(timesteps, embedding_dim):
18
+ """
19
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
20
+ From Fairseq.
21
+ Build sinusoidal embeddings.
22
+ This matches the implementation in tensor2tensor, but differs slightly
23
+ from the description in Section 3.5 of "Attention Is All You Need".
24
+ """
25
+ assert len(timesteps.shape) == 1
26
+
27
+ half_dim = embedding_dim // 2
28
+ emb = math.log(10000) / (half_dim - 1)
29
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
30
+ emb = emb.to(device=timesteps.device)
31
+ emb = timesteps.float()[:, None] * emb[None, :]
32
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
33
+ if embedding_dim % 2 == 1: # zero pad
34
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
35
+ return emb
36
+
37
+
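A quick standalone check of the sinusoidal embedding above (not part of the uploaded file; the import assumes this repository is on PYTHONPATH): for N timesteps and an even embedding width D it returns an [N, D] tensor whose first half holds the sin terms and second half the cos terms.

# Minimal sanity check for get_timestep_embedding (illustrative, run separately).
import torch
from ldm_patched.ldm.modules.diffusionmodules.model import get_timestep_embedding

t = torch.tensor([0, 10, 999])                  # a 1-D batch of timesteps
emb = get_timestep_embedding(t, embedding_dim=128)
assert emb.shape == (3, 128)
# timestep 0 maps to sin(0)=0 in the first half and cos(0)=1 in the second half
assert torch.allclose(emb[0, :64], torch.zeros(64))
assert torch.allclose(emb[0, 64:], torch.ones(64))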
38
+ def nonlinearity(x):
39
+ # swish
40
+ return x*torch.sigmoid(x)
41
+
42
+
43
+ def Normalize(in_channels, num_groups=32):
44
+ return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
45
+
46
+
47
+ class Upsample(nn.Module):
48
+ def __init__(self, in_channels, with_conv):
49
+ super().__init__()
50
+ self.with_conv = with_conv
51
+ if self.with_conv:
52
+ self.conv = ops.Conv2d(in_channels,
53
+ in_channels,
54
+ kernel_size=3,
55
+ stride=1,
56
+ padding=1)
57
+
58
+ def forward(self, x):
59
+ try:
60
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
61
+ except: #operation not implemented for bf16
62
+ b, c, h, w = x.shape
63
+ out = torch.empty((b, c, h*2, w*2), dtype=x.dtype, layout=x.layout, device=x.device)
64
+ split = 8
65
+ l = out.shape[1] // split
66
+ for i in range(0, out.shape[1], l):
67
+ out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=2.0, mode="nearest").to(x.dtype)
68
+ del x
69
+ x = out
70
+
71
+ if self.with_conv:
72
+ x = self.conv(x)
73
+ return x
74
+
75
+
76
+ class Downsample(nn.Module):
77
+ def __init__(self, in_channels, with_conv):
78
+ super().__init__()
79
+ self.with_conv = with_conv
80
+ if self.with_conv:
81
+ # no asymmetric padding in torch conv, must do it ourselves
82
+ self.conv = ops.Conv2d(in_channels,
83
+ in_channels,
84
+ kernel_size=3,
85
+ stride=2,
86
+ padding=0)
87
+
88
+ def forward(self, x):
89
+ if self.with_conv:
90
+ pad = (0,1,0,1)
91
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
92
+ x = self.conv(x)
93
+ else:
94
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
95
+ return x
96
+
97
+
98
+ class ResnetBlock(nn.Module):
99
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
100
+ dropout, temb_channels=512):
101
+ super().__init__()
102
+ self.in_channels = in_channels
103
+ out_channels = in_channels if out_channels is None else out_channels
104
+ self.out_channels = out_channels
105
+ self.use_conv_shortcut = conv_shortcut
106
+
107
+ self.swish = torch.nn.SiLU(inplace=True)
108
+ self.norm1 = Normalize(in_channels)
109
+ self.conv1 = ops.Conv2d(in_channels,
110
+ out_channels,
111
+ kernel_size=3,
112
+ stride=1,
113
+ padding=1)
114
+ if temb_channels > 0:
115
+ self.temb_proj = ops.Linear(temb_channels,
116
+ out_channels)
117
+ self.norm2 = Normalize(out_channels)
118
+ self.dropout = torch.nn.Dropout(dropout, inplace=True)
119
+ self.conv2 = ops.Conv2d(out_channels,
120
+ out_channels,
121
+ kernel_size=3,
122
+ stride=1,
123
+ padding=1)
124
+ if self.in_channels != self.out_channels:
125
+ if self.use_conv_shortcut:
126
+ self.conv_shortcut = ops.Conv2d(in_channels,
127
+ out_channels,
128
+ kernel_size=3,
129
+ stride=1,
130
+ padding=1)
131
+ else:
132
+ self.nin_shortcut = ops.Conv2d(in_channels,
133
+ out_channels,
134
+ kernel_size=1,
135
+ stride=1,
136
+ padding=0)
137
+
138
+ def forward(self, x, temb):
139
+ h = x
140
+ h = self.norm1(h)
141
+ h = self.swish(h)
142
+ h = self.conv1(h)
143
+
144
+ if temb is not None:
145
+ h = h + self.temb_proj(self.swish(temb))[:,:,None,None]
146
+
147
+ h = self.norm2(h)
148
+ h = self.swish(h)
149
+ h = self.dropout(h)
150
+ h = self.conv2(h)
151
+
152
+ if self.in_channels != self.out_channels:
153
+ if self.use_conv_shortcut:
154
+ x = self.conv_shortcut(x)
155
+ else:
156
+ x = self.nin_shortcut(x)
157
+
158
+ return x+h
159
+
160
+ def slice_attention(q, k, v):
161
+ r1 = torch.zeros_like(k, device=q.device)
162
+ scale = (int(q.shape[-1])**(-0.5))
163
+
164
+ mem_free_total = model_management.get_free_memory(q.device)
165
+
166
+ gb = 1024 ** 3
167
+ tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
168
+ modifier = 3 if q.element_size() == 2 else 2.5
169
+ mem_required = tensor_size * modifier
170
+ steps = 1
171
+
172
+ if mem_required > mem_free_total:
173
+ steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
174
+
175
+ while True:
176
+ try:
177
+ slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
178
+ for i in range(0, q.shape[1], slice_size):
179
+ end = i + slice_size
180
+ s1 = torch.bmm(q[:, i:end], k) * scale
181
+
182
+ s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
183
+ del s1
184
+
185
+ r1[:, :, i:end] = torch.bmm(v, s2)
186
+ del s2
187
+ break
188
+ except model_management.OOM_EXCEPTION as e:
189
+ model_management.soft_empty_cache(True)
190
+ steps *= 2
191
+ if steps > 128:
192
+ raise e
193
+ print("out of memory error, increasing steps and trying again", steps)
194
+
195
+ return r1
196
+
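slice_attention above computes plain softmax attention, but walks the query rows in progressively smaller slices (doubling steps after each OOM, up to 128) so the b x hw x hw score matrix never has to exist in one piece. Its unsliced equivalent, in the same layout (q: b x hw x c, k: b x c x hw, v: b x c x hw, output: b x c x hw), is sketched below; this is a reference illustration, not part of the uploaded file.

# Reference (memory-hungry) computation that slice_attention reproduces slice by slice.
import torch

def attention_reference(q, k, v):
    # q: (b, hw, c), k: (b, c, hw), v: (b, c, hw)  ->  (b, c, hw), same as slice_attention
    scale = q.shape[-1] ** -0.5
    scores = torch.bmm(q, k) * scale                   # (b, hw, hw) - the big matrix
    weights = torch.softmax(scores, dim=2)             # rows sum to 1 over keys
    return torch.bmm(v, weights.permute(0, 2, 1))      # (b, c, hw)

b, c, hw = 1, 8, 16
q, k, v = torch.randn(b, hw, c), torch.randn(b, c, hw), torch.randn(b, c, hw)
out = attention_reference(q, k, v)
assert out.shape == (b, c, hw)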
197
+ def normal_attention(q, k, v):
198
+ # compute attention
199
+ b,c,h,w = q.shape
200
+
201
+ q = q.reshape(b,c,h*w)
202
+ q = q.permute(0,2,1) # b,hw,c
203
+ k = k.reshape(b,c,h*w) # b,c,hw
204
+ v = v.reshape(b,c,h*w)
205
+
206
+ r1 = slice_attention(q, k, v)
207
+ h_ = r1.reshape(b,c,h,w)
208
+ del r1
209
+ return h_
210
+
211
+ def xformers_attention(q, k, v):
212
+ # compute attention
213
+ B, C, H, W = q.shape
214
+ q, k, v = map(
215
+ lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
216
+ (q, k, v),
217
+ )
218
+
219
+ try:
220
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
221
+ out = out.transpose(1, 2).reshape(B, C, H, W)
222
+ except NotImplementedError as e:
223
+ out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
224
+ return out
225
+
226
+ def pytorch_attention(q, k, v):
227
+ # compute attention
228
+ B, C, H, W = q.shape
229
+ q, k, v = map(
230
+ lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
231
+ (q, k, v),
232
+ )
233
+
234
+ try:
235
+ out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
236
+ out = out.transpose(2, 3).reshape(B, C, H, W)
237
+ except model_management.OOM_EXCEPTION as e:
238
+ print("scaled_dot_product_attention OOMed: switched to slice attention")
239
+ out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
240
+ return out
241
+
242
+
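pytorch_attention above flattens the spatial grid into a token axis and adds a singleton head dimension before calling scaled_dot_product_attention, then folds the result back to b x c x h x w. A standalone sketch of that reshaping (not part of the uploaded file; sizes are illustrative):

# Mirrors the reshaping done in pytorch_attention around scaled_dot_product_attention.
import torch
import torch.nn.functional as F

B, C, H, W = 1, 64, 16, 16
q = torch.randn(B, C, H, W)
k = torch.randn(B, C, H, W)
v = torch.randn(B, C, H, W)

# (B, C, H, W) -> (B, 1, H*W, C): one attention "head" over H*W spatial tokens
q_, k_, v_ = (t.view(B, 1, C, -1).transpose(2, 3) for t in (q, k, v))
out = F.scaled_dot_product_attention(q_, k_, v_, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(B, C, H, W)
assert out.shape == (B, C, H, W)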
243
+ class AttnBlock(nn.Module):
244
+ def __init__(self, in_channels):
245
+ super().__init__()
246
+ self.in_channels = in_channels
247
+
248
+ self.norm = Normalize(in_channels)
249
+ self.q = ops.Conv2d(in_channels,
250
+ in_channels,
251
+ kernel_size=1,
252
+ stride=1,
253
+ padding=0)
254
+ self.k = ops.Conv2d(in_channels,
255
+ in_channels,
256
+ kernel_size=1,
257
+ stride=1,
258
+ padding=0)
259
+ self.v = ops.Conv2d(in_channels,
260
+ in_channels,
261
+ kernel_size=1,
262
+ stride=1,
263
+ padding=0)
264
+ self.proj_out = ops.Conv2d(in_channels,
265
+ in_channels,
266
+ kernel_size=1,
267
+ stride=1,
268
+ padding=0)
269
+
270
+ if model_management.xformers_enabled_vae():
271
+ print("Using xformers attention in VAE")
272
+ self.optimized_attention = xformers_attention
273
+ elif model_management.pytorch_attention_enabled():
274
+ print("Using pytorch attention in VAE")
275
+ self.optimized_attention = pytorch_attention
276
+ else:
277
+ print("Using split attention in VAE")
278
+ self.optimized_attention = normal_attention
279
+
280
+ def forward(self, x):
281
+ h_ = x
282
+ h_ = self.norm(h_)
283
+ q = self.q(h_)
284
+ k = self.k(h_)
285
+ v = self.v(h_)
286
+
287
+ h_ = self.optimized_attention(q, k, v)
288
+
289
+ h_ = self.proj_out(h_)
290
+
291
+ return x+h_
292
+
293
+
294
+ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
295
+ return AttnBlock(in_channels)
296
+
297
+
298
+ class Model(nn.Module):
299
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
300
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
301
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
302
+ super().__init__()
303
+ if use_linear_attn: attn_type = "linear"
304
+ self.ch = ch
305
+ self.temb_ch = self.ch*4
306
+ self.num_resolutions = len(ch_mult)
307
+ self.num_res_blocks = num_res_blocks
308
+ self.resolution = resolution
309
+ self.in_channels = in_channels
310
+
311
+ self.use_timestep = use_timestep
312
+ if self.use_timestep:
313
+ # timestep embedding
314
+ self.temb = nn.Module()
315
+ self.temb.dense = nn.ModuleList([
316
+ ops.Linear(self.ch,
317
+ self.temb_ch),
318
+ ops.Linear(self.temb_ch,
319
+ self.temb_ch),
320
+ ])
321
+
322
+ # downsampling
323
+ self.conv_in = ops.Conv2d(in_channels,
324
+ self.ch,
325
+ kernel_size=3,
326
+ stride=1,
327
+ padding=1)
328
+
329
+ curr_res = resolution
330
+ in_ch_mult = (1,)+tuple(ch_mult)
331
+ self.down = nn.ModuleList()
332
+ for i_level in range(self.num_resolutions):
333
+ block = nn.ModuleList()
334
+ attn = nn.ModuleList()
335
+ block_in = ch*in_ch_mult[i_level]
336
+ block_out = ch*ch_mult[i_level]
337
+ for i_block in range(self.num_res_blocks):
338
+ block.append(ResnetBlock(in_channels=block_in,
339
+ out_channels=block_out,
340
+ temb_channels=self.temb_ch,
341
+ dropout=dropout))
342
+ block_in = block_out
343
+ if curr_res in attn_resolutions:
344
+ attn.append(make_attn(block_in, attn_type=attn_type))
345
+ down = nn.Module()
346
+ down.block = block
347
+ down.attn = attn
348
+ if i_level != self.num_resolutions-1:
349
+ down.downsample = Downsample(block_in, resamp_with_conv)
350
+ curr_res = curr_res // 2
351
+ self.down.append(down)
352
+
353
+ # middle
354
+ self.mid = nn.Module()
355
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
356
+ out_channels=block_in,
357
+ temb_channels=self.temb_ch,
358
+ dropout=dropout)
359
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
360
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
361
+ out_channels=block_in,
362
+ temb_channels=self.temb_ch,
363
+ dropout=dropout)
364
+
365
+ # upsampling
366
+ self.up = nn.ModuleList()
367
+ for i_level in reversed(range(self.num_resolutions)):
368
+ block = nn.ModuleList()
369
+ attn = nn.ModuleList()
370
+ block_out = ch*ch_mult[i_level]
371
+ skip_in = ch*ch_mult[i_level]
372
+ for i_block in range(self.num_res_blocks+1):
373
+ if i_block == self.num_res_blocks:
374
+ skip_in = ch*in_ch_mult[i_level]
375
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
376
+ out_channels=block_out,
377
+ temb_channels=self.temb_ch,
378
+ dropout=dropout))
379
+ block_in = block_out
380
+ if curr_res in attn_resolutions:
381
+ attn.append(make_attn(block_in, attn_type=attn_type))
382
+ up = nn.Module()
383
+ up.block = block
384
+ up.attn = attn
385
+ if i_level != 0:
386
+ up.upsample = Upsample(block_in, resamp_with_conv)
387
+ curr_res = curr_res * 2
388
+ self.up.insert(0, up) # prepend to get consistent order
389
+
390
+ # end
391
+ self.norm_out = Normalize(block_in)
392
+ self.conv_out = ops.Conv2d(block_in,
393
+ out_ch,
394
+ kernel_size=3,
395
+ stride=1,
396
+ padding=1)
397
+
398
+ def forward(self, x, t=None, context=None):
399
+ #assert x.shape[2] == x.shape[3] == self.resolution
400
+ if context is not None:
401
+ # assume aligned context, cat along channel axis
402
+ x = torch.cat((x, context), dim=1)
403
+ if self.use_timestep:
404
+ # timestep embedding
405
+ assert t is not None
406
+ temb = get_timestep_embedding(t, self.ch)
407
+ temb = self.temb.dense[0](temb)
408
+ temb = nonlinearity(temb)
409
+ temb = self.temb.dense[1](temb)
410
+ else:
411
+ temb = None
412
+
413
+ # downsampling
414
+ hs = [self.conv_in(x)]
415
+ for i_level in range(self.num_resolutions):
416
+ for i_block in range(self.num_res_blocks):
417
+ h = self.down[i_level].block[i_block](hs[-1], temb)
418
+ if len(self.down[i_level].attn) > 0:
419
+ h = self.down[i_level].attn[i_block](h)
420
+ hs.append(h)
421
+ if i_level != self.num_resolutions-1:
422
+ hs.append(self.down[i_level].downsample(hs[-1]))
423
+
424
+ # middle
425
+ h = hs[-1]
426
+ h = self.mid.block_1(h, temb)
427
+ h = self.mid.attn_1(h)
428
+ h = self.mid.block_2(h, temb)
429
+
430
+ # upsampling
431
+ for i_level in reversed(range(self.num_resolutions)):
432
+ for i_block in range(self.num_res_blocks+1):
433
+ h = self.up[i_level].block[i_block](
434
+ torch.cat([h, hs.pop()], dim=1), temb)
435
+ if len(self.up[i_level].attn) > 0:
436
+ h = self.up[i_level].attn[i_block](h)
437
+ if i_level != 0:
438
+ h = self.up[i_level].upsample(h)
439
+
440
+ # end
441
+ h = self.norm_out(h)
442
+ h = nonlinearity(h)
443
+ h = self.conv_out(h)
444
+ return h
445
+
446
+ def get_last_layer(self):
447
+ return self.conv_out.weight
448
+
449
+
450
+ class Encoder(nn.Module):
451
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
452
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
453
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
454
+ **ignore_kwargs):
455
+ super().__init__()
456
+ if use_linear_attn: attn_type = "linear"
457
+ self.ch = ch
458
+ self.temb_ch = 0
459
+ self.num_resolutions = len(ch_mult)
460
+ self.num_res_blocks = num_res_blocks
461
+ self.resolution = resolution
462
+ self.in_channels = in_channels
463
+
464
+ # downsampling
465
+ self.conv_in = ops.Conv2d(in_channels,
466
+ self.ch,
467
+ kernel_size=3,
468
+ stride=1,
469
+ padding=1)
470
+
471
+ curr_res = resolution
472
+ in_ch_mult = (1,)+tuple(ch_mult)
473
+ self.in_ch_mult = in_ch_mult
474
+ self.down = nn.ModuleList()
475
+ for i_level in range(self.num_resolutions):
476
+ block = nn.ModuleList()
477
+ attn = nn.ModuleList()
478
+ block_in = ch*in_ch_mult[i_level]
479
+ block_out = ch*ch_mult[i_level]
480
+ for i_block in range(self.num_res_blocks):
481
+ block.append(ResnetBlock(in_channels=block_in,
482
+ out_channels=block_out,
483
+ temb_channels=self.temb_ch,
484
+ dropout=dropout))
485
+ block_in = block_out
486
+ if curr_res in attn_resolutions:
487
+ attn.append(make_attn(block_in, attn_type=attn_type))
488
+ down = nn.Module()
489
+ down.block = block
490
+ down.attn = attn
491
+ if i_level != self.num_resolutions-1:
492
+ down.downsample = Downsample(block_in, resamp_with_conv)
493
+ curr_res = curr_res // 2
494
+ self.down.append(down)
495
+
496
+ # middle
497
+ self.mid = nn.Module()
498
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
499
+ out_channels=block_in,
500
+ temb_channels=self.temb_ch,
501
+ dropout=dropout)
502
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
503
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
504
+ out_channels=block_in,
505
+ temb_channels=self.temb_ch,
506
+ dropout=dropout)
507
+
508
+ # end
509
+ self.norm_out = Normalize(block_in)
510
+ self.conv_out = ops.Conv2d(block_in,
511
+ 2*z_channels if double_z else z_channels,
512
+ kernel_size=3,
513
+ stride=1,
514
+ padding=1)
515
+
516
+ def forward(self, x):
517
+ # timestep embedding
518
+ temb = None
519
+ # downsampling
520
+ h = self.conv_in(x)
521
+ for i_level in range(self.num_resolutions):
522
+ for i_block in range(self.num_res_blocks):
523
+ h = self.down[i_level].block[i_block](h, temb)
524
+ if len(self.down[i_level].attn) > 0:
525
+ h = self.down[i_level].attn[i_block](h)
526
+ if i_level != self.num_resolutions-1:
527
+ h = self.down[i_level].downsample(h)
528
+
529
+ # middle
530
+ h = self.mid.block_1(h, temb)
531
+ h = self.mid.attn_1(h)
532
+ h = self.mid.block_2(h, temb)
533
+
534
+ # end
535
+ h = self.norm_out(h)
536
+ h = nonlinearity(h)
537
+ h = self.conv_out(h)
538
+ return h
539
+
540
+
541
+ class Decoder(nn.Module):
542
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
543
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
544
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
545
+ conv_out_op=ops.Conv2d,
546
+ resnet_op=ResnetBlock,
547
+ attn_op=AttnBlock,
548
+ **ignorekwargs):
549
+ super().__init__()
550
+ if use_linear_attn: attn_type = "linear"
551
+ self.ch = ch
552
+ self.temb_ch = 0
553
+ self.num_resolutions = len(ch_mult)
554
+ self.num_res_blocks = num_res_blocks
555
+ self.resolution = resolution
556
+ self.in_channels = in_channels
557
+ self.give_pre_end = give_pre_end
558
+ self.tanh_out = tanh_out
559
+
560
+ # compute in_ch_mult, block_in and curr_res at lowest res
561
+ in_ch_mult = (1,)+tuple(ch_mult)
562
+ block_in = ch*ch_mult[self.num_resolutions-1]
563
+ curr_res = resolution // 2**(self.num_resolutions-1)
564
+ self.z_shape = (1,z_channels,curr_res,curr_res)
565
+ print("Working with z of shape {} = {} dimensions.".format(
566
+ self.z_shape, np.prod(self.z_shape)))
567
+
568
+ # z to block_in
569
+ self.conv_in = ops.Conv2d(z_channels,
570
+ block_in,
571
+ kernel_size=3,
572
+ stride=1,
573
+ padding=1)
574
+
575
+ # middle
576
+ self.mid = nn.Module()
577
+ self.mid.block_1 = resnet_op(in_channels=block_in,
578
+ out_channels=block_in,
579
+ temb_channels=self.temb_ch,
580
+ dropout=dropout)
581
+ self.mid.attn_1 = attn_op(block_in)
582
+ self.mid.block_2 = resnet_op(in_channels=block_in,
583
+ out_channels=block_in,
584
+ temb_channels=self.temb_ch,
585
+ dropout=dropout)
586
+
587
+ # upsampling
588
+ self.up = nn.ModuleList()
589
+ for i_level in reversed(range(self.num_resolutions)):
590
+ block = nn.ModuleList()
591
+ attn = nn.ModuleList()
592
+ block_out = ch*ch_mult[i_level]
593
+ for i_block in range(self.num_res_blocks+1):
594
+ block.append(resnet_op(in_channels=block_in,
595
+ out_channels=block_out,
596
+ temb_channels=self.temb_ch,
597
+ dropout=dropout))
598
+ block_in = block_out
599
+ if curr_res in attn_resolutions:
600
+ attn.append(attn_op(block_in))
601
+ up = nn.Module()
602
+ up.block = block
603
+ up.attn = attn
604
+ if i_level != 0:
605
+ up.upsample = Upsample(block_in, resamp_with_conv)
606
+ curr_res = curr_res * 2
607
+ self.up.insert(0, up) # prepend to get consistent order
608
+
609
+ # end
610
+ self.norm_out = Normalize(block_in)
611
+ self.conv_out = conv_out_op(block_in,
612
+ out_ch,
613
+ kernel_size=3,
614
+ stride=1,
615
+ padding=1)
616
+
617
+ def forward(self, z, **kwargs):
618
+ #assert z.shape[1:] == self.z_shape[1:]
619
+ self.last_z_shape = z.shape
620
+
621
+ # timestep embedding
622
+ temb = None
623
+
624
+ # z to block_in
625
+ h = self.conv_in(z)
626
+
627
+ # middle
628
+ h = self.mid.block_1(h, temb, **kwargs)
629
+ h = self.mid.attn_1(h, **kwargs)
630
+ h = self.mid.block_2(h, temb, **kwargs)
631
+
632
+ # upsampling
633
+ for i_level in reversed(range(self.num_resolutions)):
634
+ for i_block in range(self.num_res_blocks+1):
635
+ h = self.up[i_level].block[i_block](h, temb, **kwargs)
636
+ if len(self.up[i_level].attn) > 0:
637
+ h = self.up[i_level].attn[i_block](h, **kwargs)
638
+ if i_level != 0:
639
+ h = self.up[i_level].upsample(h)
640
+
641
+ # end
642
+ if self.give_pre_end:
643
+ return h
644
+
645
+ h = self.norm_out(h)
646
+ h = nonlinearity(h)
647
+ h = self.conv_out(h, **kwargs)
648
+ if self.tanh_out:
649
+ h = torch.tanh(h)
650
+ return h
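A rough end-to-end sketch of how the Encoder and Decoder above pair up in a VAE-style autoencoder (not part of the uploaded file; the hyperparameters are illustrative and only roughly follow common Stable Diffusion VAE settings, and the import assumes this repository is on PYTHONPATH):

# Sketch: encode a 256x256 image to a latent and decode it back (illustrative sizes).
import torch
from ldm_patched.ldm.modules.diffusionmodules.model import Encoder, Decoder

cfg = dict(ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
           attn_resolutions=[], in_channels=3, resolution=256,
           z_channels=4, dropout=0.0)
enc = Encoder(double_z=True, **cfg)
dec = Decoder(**cfg)

x = torch.randn(1, 3, 256, 256)
moments = enc(x)                 # (1, 8, 32, 32): mean and logvar stacked (double_z)
z = moments[:, :4]               # take the mean half as a stand-in for sampling
rec = dec(z)                     # (1, 3, 256, 256)
assert rec.shape == x.shape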
ldm_patched/ldm/modules/diffusionmodules/openaimodel.py ADDED
@@ -0,0 +1,886 @@
1
+ from abc import abstractmethod
2
+
3
+ import torch as th
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+
8
+ from .util import (
9
+ checkpoint,
10
+ avg_pool_nd,
11
+ zero_module,
12
+ timestep_embedding,
13
+ AlphaBlender,
14
+ )
15
+ from ..attention import SpatialTransformer, SpatialVideoTransformer, default
16
+ from ldm_patched.ldm.util import exists
17
+ import ldm_patched.modules.ops
18
+ ops = ldm_patched.modules.ops.disable_weight_init
19
+
20
+ class TimestepBlock(nn.Module):
21
+ """
22
+ Any module where forward() takes timestep embeddings as a second argument.
23
+ """
24
+
25
+ @abstractmethod
26
+ def forward(self, x, emb):
27
+ """
28
+ Apply the module to `x` given `emb` timestep embeddings.
29
+ """
30
+
31
+ #This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index"
32
+ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
33
+ for layer in ts:
34
+ if isinstance(layer, VideoResBlock):
35
+ x = layer(x, emb, num_video_frames, image_only_indicator)
36
+ elif isinstance(layer, TimestepBlock):
37
+ x = layer(x, emb)
38
+ elif isinstance(layer, SpatialVideoTransformer):
39
+ x = layer(x, context, time_context, num_video_frames, image_only_indicator, transformer_options)
40
+ if "transformer_index" in transformer_options:
41
+ transformer_options["transformer_index"] += 1
42
+ elif isinstance(layer, SpatialTransformer):
43
+ x = layer(x, context, transformer_options)
44
+ if "transformer_index" in transformer_options:
45
+ transformer_options["transformer_index"] += 1
46
+ elif isinstance(layer, Upsample):
47
+ x = layer(x, output_shape=output_shape)
48
+ else:
49
+ x = layer(x)
50
+ return x
51
+
52
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
53
+ """
54
+ A sequential module that passes timestep embeddings to the children that
55
+ support it as an extra input.
56
+ """
57
+
58
+ def forward(self, *args, **kwargs):
59
+ return forward_timestep_embed(self, *args, **kwargs)
60
+
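forward_timestep_embed above is the dispatcher that lets one Sequential-style container mix plain layers, ResBlocks (which need emb), and SpatialTransformers (which need context and transformer_options). A rough usage sketch, not part of the uploaded file; channel sizes are illustrative and the imports assume this repository is on PYTHONPATH:

# Sketch: a TimestepEmbedSequential routes emb/context only to children that take them.
import torch
from ldm_patched.ldm.modules.diffusionmodules.openaimodel import (
    TimestepEmbedSequential, ResBlock)
from ldm_patched.ldm.modules.attention import SpatialTransformer

block = TimestepEmbedSequential(
    ResBlock(channels=64, emb_channels=256, dropout=0.0, use_checkpoint=False),
    SpatialTransformer(in_channels=64, n_heads=4, d_head=16,
                       depth=1, context_dim=768, use_checkpoint=False),
)
x = torch.randn(1, 64, 32, 32)
emb = torch.randn(1, 256)              # timestep embedding, consumed by the ResBlock
ctx = torch.randn(1, 77, 768)          # cross-attention context, consumed by the transformer
out = block(x, emb, context=ctx)
assert out.shape == x.shape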
61
+ class Upsample(nn.Module):
62
+ """
63
+ An upsampling layer with an optional convolution.
64
+ :param channels: channels in the inputs and outputs.
65
+ :param use_conv: a bool determining if a convolution is applied.
66
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
67
+ upsampling occurs in the inner-two dimensions.
68
+ """
69
+
70
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=ops):
71
+ super().__init__()
72
+ self.channels = channels
73
+ self.out_channels = out_channels or channels
74
+ self.use_conv = use_conv
75
+ self.dims = dims
76
+ if use_conv:
77
+ self.conv = operations.conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype, device=device)
78
+
79
+ def forward(self, x, output_shape=None):
80
+ assert x.shape[1] == self.channels
81
+ if self.dims == 3:
82
+ shape = [x.shape[2], x.shape[3] * 2, x.shape[4] * 2]
83
+ if output_shape is not None:
84
+ shape[1] = output_shape[3]
85
+ shape[2] = output_shape[4]
86
+ else:
87
+ shape = [x.shape[2] * 2, x.shape[3] * 2]
88
+ if output_shape is not None:
89
+ shape[0] = output_shape[2]
90
+ shape[1] = output_shape[3]
91
+
92
+ x = F.interpolate(x, size=shape, mode="nearest")
93
+ if self.use_conv:
94
+ x = self.conv(x)
95
+ return x
96
+
97
+ class Downsample(nn.Module):
98
+ """
99
+ A downsampling layer with an optional convolution.
100
+ :param channels: channels in the inputs and outputs.
101
+ :param use_conv: a bool determining if a convolution is applied.
102
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
103
+ downsampling occurs in the inner-two dimensions.
104
+ """
105
+
106
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=ops):
107
+ super().__init__()
108
+ self.channels = channels
109
+ self.out_channels = out_channels or channels
110
+ self.use_conv = use_conv
111
+ self.dims = dims
112
+ stride = 2 if dims != 3 else (1, 2, 2)
113
+ if use_conv:
114
+ self.op = operations.conv_nd(
115
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype, device=device
116
+ )
117
+ else:
118
+ assert self.channels == self.out_channels
119
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
120
+
121
+ def forward(self, x):
122
+ assert x.shape[1] == self.channels
123
+ return self.op(x)
124
+
125
+
126
+ class ResBlock(TimestepBlock):
127
+ """
128
+ A residual block that can optionally change the number of channels.
129
+ :param channels: the number of input channels.
130
+ :param emb_channels: the number of timestep embedding channels.
131
+ :param dropout: the rate of dropout.
132
+ :param out_channels: if specified, the number of out channels.
133
+ :param use_conv: if True and out_channels is specified, use a spatial
134
+ convolution instead of a smaller 1x1 convolution to change the
135
+ channels in the skip connection.
136
+ :param dims: determines if the signal is 1D, 2D, or 3D.
137
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
138
+ :param up: if True, use this block for upsampling.
139
+ :param down: if True, use this block for downsampling.
140
+ """
141
+
142
+ def __init__(
143
+ self,
144
+ channels,
145
+ emb_channels,
146
+ dropout,
147
+ out_channels=None,
148
+ use_conv=False,
149
+ use_scale_shift_norm=False,
150
+ dims=2,
151
+ use_checkpoint=False,
152
+ up=False,
153
+ down=False,
154
+ kernel_size=3,
155
+ exchange_temb_dims=False,
156
+ skip_t_emb=False,
157
+ dtype=None,
158
+ device=None,
159
+ operations=ops
160
+ ):
161
+ super().__init__()
162
+ self.channels = channels
163
+ self.emb_channels = emb_channels
164
+ self.dropout = dropout
165
+ self.out_channels = out_channels or channels
166
+ self.use_conv = use_conv
167
+ self.use_checkpoint = use_checkpoint
168
+ self.use_scale_shift_norm = use_scale_shift_norm
169
+ self.exchange_temb_dims = exchange_temb_dims
170
+
171
+ if isinstance(kernel_size, list):
172
+ padding = [k // 2 for k in kernel_size]
173
+ else:
174
+ padding = kernel_size // 2
175
+
176
+ self.in_layers = nn.Sequential(
177
+ operations.GroupNorm(32, channels, dtype=dtype, device=device),
178
+ nn.SiLU(),
179
+ operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device),
180
+ )
181
+
182
+ self.updown = up or down
183
+
184
+ if up:
185
+ self.h_upd = Upsample(channels, False, dims, dtype=dtype, device=device)
186
+ self.x_upd = Upsample(channels, False, dims, dtype=dtype, device=device)
187
+ elif down:
188
+ self.h_upd = Downsample(channels, False, dims, dtype=dtype, device=device)
189
+ self.x_upd = Downsample(channels, False, dims, dtype=dtype, device=device)
190
+ else:
191
+ self.h_upd = self.x_upd = nn.Identity()
192
+
193
+ self.skip_t_emb = skip_t_emb
194
+ if self.skip_t_emb:
195
+ self.emb_layers = None
196
+ self.exchange_temb_dims = False
197
+ else:
198
+ self.emb_layers = nn.Sequential(
199
+ nn.SiLU(),
200
+ operations.Linear(
201
+ emb_channels,
202
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
203
+ ),
204
+ )
205
+ self.out_layers = nn.Sequential(
206
+ operations.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
207
+ nn.SiLU(),
208
+ nn.Dropout(p=dropout),
209
+ operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
210
+ ,
211
+ )
212
+
213
+ if self.out_channels == channels:
214
+ self.skip_connection = nn.Identity()
215
+ elif use_conv:
216
+ self.skip_connection = operations.conv_nd(
217
+ dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device
218
+ )
219
+ else:
220
+ self.skip_connection = operations.conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)
221
+
222
+ def forward(self, x, emb):
223
+ """
224
+ Apply the block to a Tensor, conditioned on a timestep embedding.
225
+ :param x: an [N x C x ...] Tensor of features.
226
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
227
+ :return: an [N x C x ...] Tensor of outputs.
228
+ """
229
+ return checkpoint(
230
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
231
+ )
232
+
233
+
234
+ def _forward(self, x, emb):
235
+ if self.updown:
236
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
237
+ h = in_rest(x)
238
+ h = self.h_upd(h)
239
+ x = self.x_upd(x)
240
+ h = in_conv(h)
241
+ else:
242
+ h = self.in_layers(x)
243
+
244
+ emb_out = None
245
+ if not self.skip_t_emb:
246
+ emb_out = self.emb_layers(emb).type(h.dtype)
247
+ while len(emb_out.shape) < len(h.shape):
248
+ emb_out = emb_out[..., None]
249
+ if self.use_scale_shift_norm:
250
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
251
+ h = out_norm(h)
252
+ if emb_out is not None:
253
+ scale, shift = th.chunk(emb_out, 2, dim=1)
254
+ h *= (1 + scale)
255
+ h += shift
256
+ h = out_rest(h)
257
+ else:
258
+ if emb_out is not None:
259
+ if self.exchange_temb_dims:
260
+ emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
261
+ h = h + emb_out
262
+ h = self.out_layers(h)
263
+ return self.skip_connection(x) + h
264
+
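When use_scale_shift_norm is set, the timestep embedding is projected to 2xC channels and split into a per-channel scale and shift that modulate the normalized features (h = norm(h) * (1 + scale) + shift) before the rest of out_layers runs; otherwise the embedding is simply added to h. A standalone sketch of the scale-shift branch (not part of the uploaded file; sizes are illustrative):

# FiLM-style conditioning as used in ResBlock._forward when use_scale_shift_norm=True.
import torch
import torch.nn as nn

C, emb_dim = 64, 256
norm = nn.GroupNorm(32, C)
emb_proj = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, 2 * C))   # -> scale and shift

h = torch.randn(2, C, 16, 16)
emb = torch.randn(2, emb_dim)

emb_out = emb_proj(emb)[:, :, None, None]          # (2, 2C, 1, 1), broadcast over H, W
scale, shift = torch.chunk(emb_out, 2, dim=1)      # (2, C, 1, 1) each
h = norm(h) * (1 + scale) + shift                  # modulate normalized features
assert h.shape == (2, C, 16, 16)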
265
+
266
+ class VideoResBlock(ResBlock):
267
+ def __init__(
268
+ self,
269
+ channels: int,
270
+ emb_channels: int,
271
+ dropout: float,
272
+ video_kernel_size=3,
273
+ merge_strategy: str = "fixed",
274
+ merge_factor: float = 0.5,
275
+ out_channels=None,
276
+ use_conv: bool = False,
277
+ use_scale_shift_norm: bool = False,
278
+ dims: int = 2,
279
+ use_checkpoint: bool = False,
280
+ up: bool = False,
281
+ down: bool = False,
282
+ dtype=None,
283
+ device=None,
284
+ operations=ops
285
+ ):
286
+ super().__init__(
287
+ channels,
288
+ emb_channels,
289
+ dropout,
290
+ out_channels=out_channels,
291
+ use_conv=use_conv,
292
+ use_scale_shift_norm=use_scale_shift_norm,
293
+ dims=dims,
294
+ use_checkpoint=use_checkpoint,
295
+ up=up,
296
+ down=down,
297
+ dtype=dtype,
298
+ device=device,
299
+ operations=operations
300
+ )
301
+
302
+ self.time_stack = ResBlock(
303
+ default(out_channels, channels),
304
+ emb_channels,
305
+ dropout=dropout,
306
+ dims=3,
307
+ out_channels=default(out_channels, channels),
308
+ use_scale_shift_norm=False,
309
+ use_conv=False,
310
+ up=False,
311
+ down=False,
312
+ kernel_size=video_kernel_size,
313
+ use_checkpoint=use_checkpoint,
314
+ exchange_temb_dims=True,
315
+ dtype=dtype,
316
+ device=device,
317
+ operations=operations
318
+ )
319
+ self.time_mixer = AlphaBlender(
320
+ alpha=merge_factor,
321
+ merge_strategy=merge_strategy,
322
+ rearrange_pattern="b t -> b 1 t 1 1",
323
+ )
324
+
325
+ def forward(
326
+ self,
327
+ x: th.Tensor,
328
+ emb: th.Tensor,
329
+ num_video_frames: int,
330
+ image_only_indicator = None,
331
+ ) -> th.Tensor:
332
+ x = super().forward(x, emb)
333
+
334
+ x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
335
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
336
+
337
+ x = self.time_stack(
338
+ x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
339
+ )
340
+ x = self.time_mixer(
341
+ x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
342
+ )
343
+ x = rearrange(x, "b c t h w -> (b t) c h w")
344
+ return x
345
+
346
+
347
+ class Timestep(nn.Module):
348
+ def __init__(self, dim):
349
+ super().__init__()
350
+ self.dim = dim
351
+
352
+ def forward(self, t):
353
+ return timestep_embedding(t, self.dim)
354
+
355
+ def apply_control(h, control, name):
356
+ if control is not None and name in control and len(control[name]) > 0:
357
+ ctrl = control[name].pop()
358
+ if ctrl is not None:
359
+ try:
360
+ h += ctrl
361
+ except:
362
+ print("warning control could not be applied", h.shape, ctrl.shape)
363
+ return h
364
+
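apply_control above pops one residual per call from control[name] and adds it to the current hidden state, which is how ControlNet-style outputs get merged into the UNet; shapes are expected to match, and a mismatch only prints a warning. A minimal sketch, not part of the uploaded file; the "output" key name follows how this file is typically driven and should be treated as an assumption, with tensor sizes purely illustrative:

# Sketch: feeding a ControlNet-style residual through apply_control (illustrative).
import torch
from ldm_patched.ldm.modules.diffusionmodules.openaimodel import apply_control

h = torch.zeros(1, 320, 32, 32)                       # hidden state at some UNet block
control = {"output": [torch.ones(1, 320, 32, 32)]}    # one residual queued per block

h = apply_control(h, control, "output")               # pops the residual and adds it
assert torch.equal(h, torch.ones(1, 320, 32, 32))
assert control["output"] == []                        # the list is consumed in place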
365
+ class UNetModel(nn.Module):
366
+ """
367
+ The full UNet model with attention and timestep embedding.
368
+ :param in_channels: channels in the input Tensor.
369
+ :param model_channels: base channel count for the model.
370
+ :param out_channels: channels in the output Tensor.
371
+ :param num_res_blocks: number of residual blocks per downsample.
372
+ :param dropout: the dropout probability.
373
+ :param channel_mult: channel multiplier for each level of the UNet.
374
+ :param conv_resample: if True, use learned convolutions for upsampling and
375
+ downsampling.
376
+ :param dims: determines if the signal is 1D, 2D, or 3D.
377
+ :param num_classes: if specified (as an int), then this model will be
378
+ class-conditional with `num_classes` classes.
379
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
380
+ :param num_heads: the number of attention heads in each attention layer.
381
+ :param num_heads_channels: if specified, ignore num_heads and instead use
382
+ a fixed channel width per attention head.
383
+ :param num_heads_upsample: works with num_heads to set a different number
384
+ of heads for upsampling. Deprecated.
385
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
386
+ :param resblock_updown: use residual blocks for up/downsampling.
387
+ :param use_new_attention_order: use a different attention pattern for potentially
388
+ increased efficiency.
389
+ """
390
+
391
+ def __init__(
392
+ self,
393
+ image_size,
394
+ in_channels,
395
+ model_channels,
396
+ out_channels,
397
+ num_res_blocks,
398
+ dropout=0,
399
+ channel_mult=(1, 2, 4, 8),
400
+ conv_resample=True,
401
+ dims=2,
402
+ num_classes=None,
403
+ use_checkpoint=False,
404
+ dtype=th.float32,
405
+ num_heads=-1,
406
+ num_head_channels=-1,
407
+ num_heads_upsample=-1,
408
+ use_scale_shift_norm=False,
409
+ resblock_updown=False,
410
+ use_new_attention_order=False,
411
+ use_spatial_transformer=False, # custom transformer support
412
+ transformer_depth=1, # custom transformer support
413
+ context_dim=None, # custom transformer support
414
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
415
+ legacy=True,
416
+ disable_self_attentions=None,
417
+ num_attention_blocks=None,
418
+ disable_middle_self_attn=False,
419
+ use_linear_in_transformer=False,
420
+ adm_in_channels=None,
421
+ transformer_depth_middle=None,
422
+ transformer_depth_output=None,
423
+ use_temporal_resblock=False,
424
+ use_temporal_attention=False,
425
+ time_context_dim=None,
426
+ extra_ff_mix_layer=False,
427
+ use_spatial_context=False,
428
+ merge_strategy=None,
429
+ merge_factor=0.0,
430
+ video_kernel_size=None,
431
+ disable_temporal_crossattention=False,
432
+ max_ddpm_temb_period=10000,
433
+ device=None,
434
+ operations=ops,
435
+ ):
436
+ super().__init__()
437
+
438
+ if context_dim is not None:
439
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
440
+ # from omegaconf.listconfig import ListConfig
441
+ # if type(context_dim) == ListConfig:
442
+ # context_dim = list(context_dim)
443
+
444
+ if num_heads_upsample == -1:
445
+ num_heads_upsample = num_heads
446
+
447
+ if num_heads == -1:
448
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
449
+
450
+ if num_head_channels == -1:
451
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
452
+
453
+ self.in_channels = in_channels
454
+ self.model_channels = model_channels
455
+ self.out_channels = out_channels
456
+
457
+ if isinstance(num_res_blocks, int):
458
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
459
+ else:
460
+ if len(num_res_blocks) != len(channel_mult):
461
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
462
+ "as a list/tuple (per-level) with the same length as channel_mult")
463
+ self.num_res_blocks = num_res_blocks
464
+
465
+ if disable_self_attentions is not None:
466
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
467
+ assert len(disable_self_attentions) == len(channel_mult)
468
+ if num_attention_blocks is not None:
469
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
470
+
471
+ transformer_depth = transformer_depth[:]
472
+ transformer_depth_output = transformer_depth_output[:]
473
+
474
+ self.dropout = dropout
475
+ self.channel_mult = channel_mult
476
+ self.conv_resample = conv_resample
477
+ self.num_classes = num_classes
478
+ self.use_checkpoint = use_checkpoint
479
+ self.dtype = dtype
480
+ self.num_heads = num_heads
481
+ self.num_head_channels = num_head_channels
482
+ self.num_heads_upsample = num_heads_upsample
483
+ self.use_temporal_resblocks = use_temporal_resblock
484
+ self.predict_codebook_ids = n_embed is not None
485
+
486
+ self.default_num_video_frames = None
487
+ self.default_image_only_indicator = None
488
+
489
+ time_embed_dim = model_channels * 4
490
+ self.time_embed = nn.Sequential(
491
+ operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
492
+ nn.SiLU(),
493
+ operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
494
+ )
495
+
496
+ if self.num_classes is not None:
497
+ if isinstance(self.num_classes, int):
498
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim, dtype=self.dtype, device=device)
499
+ elif self.num_classes == "continuous":
500
+ print("setting up linear c_adm embedding layer")
501
+ self.label_emb = nn.Linear(1, time_embed_dim)
502
+ elif self.num_classes == "sequential":
503
+ assert adm_in_channels is not None
504
+ self.label_emb = nn.Sequential(
505
+ nn.Sequential(
506
+ operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
507
+ nn.SiLU(),
508
+ operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
509
+ )
510
+ )
511
+ else:
512
+ raise ValueError()
513
+
514
+ self.input_blocks = nn.ModuleList(
515
+ [
516
+ TimestepEmbedSequential(
517
+ operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
518
+ )
519
+ ]
520
+ )
521
+ self._feature_size = model_channels
522
+ input_block_chans = [model_channels]
523
+ ch = model_channels
524
+ ds = 1
525
+
526
+ def get_attention_layer(
527
+ ch,
528
+ num_heads,
529
+ dim_head,
530
+ depth=1,
531
+ context_dim=None,
532
+ use_checkpoint=False,
533
+ disable_self_attn=False,
534
+ ):
535
+ if use_temporal_attention:
536
+ return SpatialVideoTransformer(
537
+ ch,
538
+ num_heads,
539
+ dim_head,
540
+ depth=depth,
541
+ context_dim=context_dim,
542
+ time_context_dim=time_context_dim,
543
+ dropout=dropout,
544
+ ff_in=extra_ff_mix_layer,
545
+ use_spatial_context=use_spatial_context,
546
+ merge_strategy=merge_strategy,
547
+ merge_factor=merge_factor,
548
+ checkpoint=use_checkpoint,
549
+ use_linear=use_linear_in_transformer,
550
+ disable_self_attn=disable_self_attn,
551
+ disable_temporal_crossattention=disable_temporal_crossattention,
552
+ max_time_embed_period=max_ddpm_temb_period,
553
+ dtype=self.dtype, device=device, operations=operations
554
+ )
555
+ else:
556
+ return SpatialTransformer(
557
+ ch, num_heads, dim_head, depth=depth, context_dim=context_dim,
558
+ disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
559
+ use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
560
+ )
561
+
562
+ def get_resblock(
563
+ merge_factor,
564
+ merge_strategy,
565
+ video_kernel_size,
566
+ ch,
567
+ time_embed_dim,
568
+ dropout,
569
+ out_channels,
570
+ dims,
571
+ use_checkpoint,
572
+ use_scale_shift_norm,
573
+ down=False,
574
+ up=False,
575
+ dtype=None,
576
+ device=None,
577
+ operations=ops
578
+ ):
579
+ if self.use_temporal_resblocks:
580
+ return VideoResBlock(
581
+ merge_factor=merge_factor,
582
+ merge_strategy=merge_strategy,
583
+ video_kernel_size=video_kernel_size,
584
+ channels=ch,
585
+ emb_channels=time_embed_dim,
586
+ dropout=dropout,
587
+ out_channels=out_channels,
588
+ dims=dims,
589
+ use_checkpoint=use_checkpoint,
590
+ use_scale_shift_norm=use_scale_shift_norm,
591
+ down=down,
592
+ up=up,
593
+ dtype=dtype,
594
+ device=device,
595
+ operations=operations
596
+ )
597
+ else:
598
+ return ResBlock(
599
+ channels=ch,
600
+ emb_channels=time_embed_dim,
601
+ dropout=dropout,
602
+ out_channels=out_channels,
603
+ use_checkpoint=use_checkpoint,
604
+ dims=dims,
605
+ use_scale_shift_norm=use_scale_shift_norm,
606
+ down=down,
607
+ up=up,
608
+ dtype=dtype,
609
+ device=device,
610
+ operations=operations
611
+ )
612
+
613
+ for level, mult in enumerate(channel_mult):
614
+ for nr in range(self.num_res_blocks[level]):
615
+ layers = [
616
+ get_resblock(
617
+ merge_factor=merge_factor,
618
+ merge_strategy=merge_strategy,
619
+ video_kernel_size=video_kernel_size,
620
+ ch=ch,
621
+ time_embed_dim=time_embed_dim,
622
+ dropout=dropout,
623
+ out_channels=mult * model_channels,
624
+ dims=dims,
625
+ use_checkpoint=use_checkpoint,
626
+ use_scale_shift_norm=use_scale_shift_norm,
627
+ dtype=self.dtype,
628
+ device=device,
629
+ operations=operations,
630
+ )
631
+ ]
632
+ ch = mult * model_channels
633
+ num_transformers = transformer_depth.pop(0)
634
+ if num_transformers > 0:
635
+ if num_head_channels == -1:
636
+ dim_head = ch // num_heads
637
+ else:
638
+ num_heads = ch // num_head_channels
639
+ dim_head = num_head_channels
640
+ if legacy:
641
+ #num_heads = 1
642
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
643
+ if exists(disable_self_attentions):
644
+ disabled_sa = disable_self_attentions[level]
645
+ else:
646
+ disabled_sa = False
647
+
648
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
649
+ layers.append(get_attention_layer(
650
+ ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
651
+ disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint)
652
+ )
653
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
654
+ self._feature_size += ch
655
+ input_block_chans.append(ch)
656
+ if level != len(channel_mult) - 1:
657
+ out_ch = ch
658
+ self.input_blocks.append(
659
+ TimestepEmbedSequential(
660
+ get_resblock(
661
+ merge_factor=merge_factor,
662
+ merge_strategy=merge_strategy,
663
+ video_kernel_size=video_kernel_size,
664
+ ch=ch,
665
+ time_embed_dim=time_embed_dim,
666
+ dropout=dropout,
667
+ out_channels=out_ch,
668
+ dims=dims,
669
+ use_checkpoint=use_checkpoint,
670
+ use_scale_shift_norm=use_scale_shift_norm,
671
+ down=True,
672
+ dtype=self.dtype,
673
+ device=device,
674
+ operations=operations
675
+ )
676
+ if resblock_updown
677
+ else Downsample(
678
+ ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
679
+ )
680
+ )
681
+ )
682
+ ch = out_ch
683
+ input_block_chans.append(ch)
684
+ ds *= 2
685
+ self._feature_size += ch
686
+
687
+ if num_head_channels == -1:
688
+ dim_head = ch // num_heads
689
+ else:
690
+ num_heads = ch // num_head_channels
691
+ dim_head = num_head_channels
692
+ if legacy:
693
+ #num_heads = 1
694
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
695
+ mid_block = [
696
+ get_resblock(
697
+ merge_factor=merge_factor,
698
+ merge_strategy=merge_strategy,
699
+ video_kernel_size=video_kernel_size,
700
+ ch=ch,
701
+ time_embed_dim=time_embed_dim,
702
+ dropout=dropout,
703
+ out_channels=None,
704
+ dims=dims,
705
+ use_checkpoint=use_checkpoint,
706
+ use_scale_shift_norm=use_scale_shift_norm,
707
+ dtype=self.dtype,
708
+ device=device,
709
+ operations=operations
710
+ )]
711
+ if transformer_depth_middle >= 0:
712
+ mid_block += [get_attention_layer( # always uses a self-attn
713
+ ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
714
+ disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint
715
+ ),
716
+ get_resblock(
717
+ merge_factor=merge_factor,
718
+ merge_strategy=merge_strategy,
719
+ video_kernel_size=video_kernel_size,
720
+ ch=ch,
721
+ time_embed_dim=time_embed_dim,
722
+ dropout=dropout,
723
+ out_channels=None,
724
+ dims=dims,
725
+ use_checkpoint=use_checkpoint,
726
+ use_scale_shift_norm=use_scale_shift_norm,
727
+ dtype=self.dtype,
728
+ device=device,
729
+ operations=operations
730
+ )]
731
+ self.middle_block = TimestepEmbedSequential(*mid_block)
732
+ self._feature_size += ch
733
+
734
+ self.output_blocks = nn.ModuleList([])
735
+ for level, mult in list(enumerate(channel_mult))[::-1]:
736
+ for i in range(self.num_res_blocks[level] + 1):
737
+ ich = input_block_chans.pop()
738
+ layers = [
739
+ get_resblock(
740
+ merge_factor=merge_factor,
741
+ merge_strategy=merge_strategy,
742
+ video_kernel_size=video_kernel_size,
743
+ ch=ch + ich,
744
+ time_embed_dim=time_embed_dim,
745
+ dropout=dropout,
746
+ out_channels=model_channels * mult,
747
+ dims=dims,
748
+ use_checkpoint=use_checkpoint,
749
+ use_scale_shift_norm=use_scale_shift_norm,
750
+ dtype=self.dtype,
751
+ device=device,
752
+ operations=operations
753
+ )
754
+ ]
755
+ ch = model_channels * mult
756
+ num_transformers = transformer_depth_output.pop()
757
+ if num_transformers > 0:
758
+ if num_head_channels == -1:
759
+ dim_head = ch // num_heads
760
+ else:
761
+ num_heads = ch // num_head_channels
762
+ dim_head = num_head_channels
763
+ if legacy:
764
+ #num_heads = 1
765
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
766
+ if exists(disable_self_attentions):
767
+ disabled_sa = disable_self_attentions[level]
768
+ else:
769
+ disabled_sa = False
770
+
771
+ if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
772
+ layers.append(
773
+ get_attention_layer(
774
+ ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
775
+ disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint
776
+ )
777
+ )
778
+ if level and i == self.num_res_blocks[level]:
779
+ out_ch = ch
780
+ layers.append(
781
+ get_resblock(
782
+ merge_factor=merge_factor,
783
+ merge_strategy=merge_strategy,
784
+ video_kernel_size=video_kernel_size,
785
+ ch=ch,
786
+ time_embed_dim=time_embed_dim,
787
+ dropout=dropout,
788
+ out_channels=out_ch,
789
+ dims=dims,
790
+ use_checkpoint=use_checkpoint,
791
+ use_scale_shift_norm=use_scale_shift_norm,
792
+ up=True,
793
+ dtype=self.dtype,
794
+ device=device,
795
+ operations=operations
796
+ )
797
+ if resblock_updown
798
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations)
799
+ )
800
+ ds //= 2
801
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
802
+ self._feature_size += ch
803
+
804
+ self.out = nn.Sequential(
805
+ operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
806
+ nn.SiLU(),
807
+ zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
808
+ )
809
+ if self.predict_codebook_ids:
810
+ self.id_predictor = nn.Sequential(
811
+ operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
812
+ operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
813
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
814
+ )
815
+
816
+ def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
817
+ """
818
+ Apply the model to an input batch.
819
+ :param x: an [N x C x ...] Tensor of inputs.
820
+ :param timesteps: a 1-D batch of timesteps.
821
+ :param context: conditioning plugged in via crossattn
822
+ :param y: an [N] Tensor of labels, if class-conditional.
823
+ :return: an [N x C x ...] Tensor of outputs.
824
+ """
825
+ transformer_options["original_shape"] = list(x.shape)
826
+ transformer_options["transformer_index"] = 0
827
+ transformer_patches = transformer_options.get("patches", {})
828
+
829
+ num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
830
+ image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
831
+ time_context = kwargs.get("time_context", None)
832
+
833
+ assert (y is not None) == (
834
+ self.num_classes is not None
835
+ ), "must specify y if and only if the model is class-conditional"
836
+ hs = []
837
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
838
+ emb = self.time_embed(t_emb)
839
+
840
+ if self.num_classes is not None:
841
+ assert y.shape[0] == x.shape[0]
842
+ emb = emb + self.label_emb(y)
843
+
844
+ h = x
845
+ for id, module in enumerate(self.input_blocks):
846
+ transformer_options["block"] = ("input", id)
847
+ h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
848
+ h = apply_control(h, control, 'input')
849
+ if "input_block_patch" in transformer_patches:
850
+ patch = transformer_patches["input_block_patch"]
851
+ for p in patch:
852
+ h = p(h, transformer_options)
853
+
854
+ hs.append(h)
855
+ if "input_block_patch_after_skip" in transformer_patches:
856
+ patch = transformer_patches["input_block_patch_after_skip"]
857
+ for p in patch:
858
+ h = p(h, transformer_options)
859
+
860
+ transformer_options["block"] = ("middle", 0)
861
+ h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
862
+ h = apply_control(h, control, 'middle')
863
+
864
+
865
+ for id, module in enumerate(self.output_blocks):
866
+ transformer_options["block"] = ("output", id)
867
+ hsp = hs.pop()
868
+ hsp = apply_control(hsp, control, 'output')
869
+
870
+ if "output_block_patch" in transformer_patches:
871
+ patch = transformer_patches["output_block_patch"]
872
+ for p in patch:
873
+ h, hsp = p(h, hsp, transformer_options)
874
+
875
+ h = th.cat([h, hsp], dim=1)
876
+ del hsp
877
+ if len(hs) > 0:
878
+ output_shape = hs[-1].shape
879
+ else:
880
+ output_shape = None
881
+ h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
882
+ h = h.type(x.dtype)
883
+ if self.predict_codebook_ids:
884
+ return self.id_predictor(h)
885
+ else:
886
+ return self.out(h)
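Not part of the uploaded file — a hedged toy sketch of the skip-connection bookkeeping that forward() above performs: each input block's output is pushed onto the hs stack, and each output block pops one entry and concatenates it along the channel dimension (th.cat([h, hsp], dim=1)) before processing. The Conv2d blocks below are placeholders, not the real ResBlock/attention stacks.

import torch
import torch.nn as nn

# stand-ins for self.input_blocks / self.output_blocks
input_blocks = nn.ModuleList([nn.Conv2d(4, 4, 3, padding=1) for _ in range(3)])
output_blocks = nn.ModuleList([nn.Conv2d(8, 4, 3, padding=1) for _ in range(3)])

h = torch.randn(1, 4, 16, 16)
hs = []
for block in input_blocks:
    h = block(h)
    hs.append(h)                          # mirrors hs.append(h) in forward()
for block in output_blocks:
    h = torch.cat([h, hs.pop()], dim=1)   # mirrors th.cat([h, hsp], dim=1)
    h = block(h)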
ldm_patched/ldm/modules/diffusionmodules/upscaling.py ADDED
@@ -0,0 +1,85 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from functools import partial
5
+
6
+ from .util import extract_into_tensor, make_beta_schedule
7
+ from ldm_patched.ldm.util import default
8
+
9
+
10
+ class AbstractLowScaleModel(nn.Module):
11
+ # for concatenating a downsampled image to the latent representation
12
+ def __init__(self, noise_schedule_config=None):
13
+ super(AbstractLowScaleModel, self).__init__()
14
+ if noise_schedule_config is not None:
15
+ self.register_schedule(**noise_schedule_config)
16
+
17
+ def register_schedule(self, beta_schedule="linear", timesteps=1000,
18
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
19
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
20
+ cosine_s=cosine_s)
21
+ alphas = 1. - betas
22
+ alphas_cumprod = np.cumprod(alphas, axis=0)
23
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
24
+
25
+ timesteps, = betas.shape
26
+ self.num_timesteps = int(timesteps)
27
+ self.linear_start = linear_start
28
+ self.linear_end = linear_end
29
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
30
+
31
+ to_torch = partial(torch.tensor, dtype=torch.float32)
32
+
33
+ self.register_buffer('betas', to_torch(betas))
34
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
35
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
36
+
37
+ # calculations for diffusion q(x_t | x_{t-1}) and others
38
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
39
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
40
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
41
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
42
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
43
+
44
+ def q_sample(self, x_start, t, noise=None, seed=None):
45
+ if noise is None:
46
+ if seed is None:
47
+ noise = torch.randn_like(x_start)
48
+ else:
49
+ noise = torch.randn(x_start.size(), dtype=x_start.dtype, layout=x_start.layout, generator=torch.manual_seed(seed)).to(x_start.device)
50
+ return (extract_into_tensor(self.sqrt_alphas_cumprod.to(x_start.device), t, x_start.shape) * x_start +
51
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod.to(x_start.device), t, x_start.shape) * noise)
52
+
53
+ def forward(self, x):
54
+ return x, None
55
+
56
+ def decode(self, x):
57
+ return x
58
+
59
+
60
+ class SimpleImageConcat(AbstractLowScaleModel):
61
+ # no noise level conditioning
62
+ def __init__(self):
63
+ super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
64
+ self.max_noise_level = 0
65
+
66
+ def forward(self, x):
67
+ # fix to constant noise level
68
+ return x, torch.zeros(x.shape[0], device=x.device).long()
69
+
70
+
71
+ class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
72
+ def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
73
+ super().__init__(noise_schedule_config=noise_schedule_config)
74
+ self.max_noise_level = max_noise_level
75
+
76
+ def forward(self, x, noise_level=None, seed=None):
77
+ if noise_level is None:
78
+ noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
79
+ else:
80
+ assert isinstance(noise_level, torch.Tensor)
81
+ z = self.q_sample(x, noise_level, seed=seed)
82
+ return z, noise_level
83
+
84
+
85
+
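Not part of the uploaded file — a hedged usage sketch for ImageConcatWithNoiseAugmentation. The schedule keyword arguments are illustrative defaults (they simply feed register_schedule above), and the tensor shapes are arbitrary.

import torch

aug = ImageConcatWithNoiseAugmentation(
    noise_schedule_config={"beta_schedule": "linear", "timesteps": 1000,
                           "linear_start": 1e-4, "linear_end": 2e-2},
    max_noise_level=1000,
)
x_low = torch.randn(2, 4, 32, 32)                      # e.g. a low-res latent batch
noise_level = torch.full((2,), 250, dtype=torch.long)  # one level per batch element
z, noise_level = aug(x_low, noise_level=noise_level, seed=1234)
# z has the same shape as x_low; noise_level is returned so it can be used as conditioning.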
ldm_patched/ldm/modules/diffusionmodules/util.py ADDED
@@ -0,0 +1,304 @@
1
+ # adapted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import os
12
+ import math
13
+ import torch
14
+ import torch.nn as nn
15
+ import numpy as np
16
+ from einops import repeat, rearrange
17
+
18
+ from ldm_patched.ldm.util import instantiate_from_config
19
+
20
+ class AlphaBlender(nn.Module):
21
+ strategies = ["learned", "fixed", "learned_with_images"]
22
+
23
+ def __init__(
24
+ self,
25
+ alpha: float,
26
+ merge_strategy: str = "learned_with_images",
27
+ rearrange_pattern: str = "b t -> (b t) 1 1",
28
+ ):
29
+ super().__init__()
30
+ self.merge_strategy = merge_strategy
31
+ self.rearrange_pattern = rearrange_pattern
32
+
33
+ assert (
34
+ merge_strategy in self.strategies
35
+ ), f"merge_strategy needs to be in {self.strategies}"
36
+
37
+ if self.merge_strategy == "fixed":
38
+ self.register_buffer("mix_factor", torch.Tensor([alpha]))
39
+ elif (
40
+ self.merge_strategy == "learned"
41
+ or self.merge_strategy == "learned_with_images"
42
+ ):
43
+ self.register_parameter(
44
+ "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
45
+ )
46
+ else:
47
+ raise ValueError(f"unknown merge strategy {self.merge_strategy}")
48
+
49
+ def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
50
+ # skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t)
51
+ if self.merge_strategy == "fixed":
52
+ # make shape compatible
53
+ # alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs)
54
+ alpha = self.mix_factor.to(image_only_indicator.device)
55
+ elif self.merge_strategy == "learned":
56
+ alpha = torch.sigmoid(self.mix_factor.to(image_only_indicator.device))
57
+ # make shape compatible
58
+ # alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
59
+ elif self.merge_strategy == "learned_with_images":
60
+ assert image_only_indicator is not None, "need image_only_indicator ..."
61
+ alpha = torch.where(
62
+ image_only_indicator.bool(),
63
+ torch.ones(1, 1, device=image_only_indicator.device),
64
+ rearrange(torch.sigmoid(self.mix_factor.to(image_only_indicator.device)), "... -> ... 1"),
65
+ )
66
+ alpha = rearrange(alpha, self.rearrange_pattern)
67
+ # make shape compatible
68
+ # alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
69
+ else:
70
+ raise NotImplementedError()
71
+ return alpha
72
+
73
+ def forward(
74
+ self,
75
+ x_spatial,
76
+ x_temporal,
77
+ image_only_indicator=None,
78
+ ) -> torch.Tensor:
79
+ alpha = self.get_alpha(image_only_indicator)
80
+ x = (
81
+ alpha.to(x_spatial.dtype) * x_spatial
82
+ + (1.0 - alpha).to(x_spatial.dtype) * x_temporal
83
+ )
84
+ return x
85
+
86
+
87
+ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
88
+ if schedule == "linear":
89
+ betas = (
90
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
91
+ )
92
+
93
+ elif schedule == "cosine":
94
+ timesteps = (
95
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
96
+ )
97
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
98
+ alphas = torch.cos(alphas).pow(2)
99
+ alphas = alphas / alphas[0]
100
+ betas = 1 - alphas[1:] / alphas[:-1]
101
+ betas = np.clip(betas, a_min=0, a_max=0.999)
102
+
103
+ elif schedule == "squaredcos_cap_v2": # used for karlo prior
104
+ # return early
105
+ return betas_for_alpha_bar(
106
+ n_timestep,
107
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
108
+ )
109
+
110
+ elif schedule == "sqrt_linear":
111
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
112
+ elif schedule == "sqrt":
113
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
114
+ else:
115
+ raise ValueError(f"schedule '{schedule}' unknown.")
116
+ return betas.numpy()
117
+
118
+
119
+ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
120
+ if ddim_discr_method == 'uniform':
121
+ c = num_ddpm_timesteps // num_ddim_timesteps
122
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
123
+ elif ddim_discr_method == 'quad':
124
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
125
+ else:
126
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
127
+
128
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
129
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
130
+ steps_out = ddim_timesteps + 1
131
+ if verbose:
132
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
133
+ return steps_out
134
+
135
+
136
+ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
137
+ # select alphas for computing the variance schedule
138
+ alphas = alphacums[ddim_timesteps]
139
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
140
+
141
+ # according to the formula provided in https://arxiv.org/abs/2010.02502
142
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
143
+ if verbose:
144
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
145
+ print(f'For the chosen value of eta, which is {eta}, '
146
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
147
+ return sigmas, alphas, alphas_prev
148
+
149
+
150
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
151
+ """
152
+ Create a beta schedule that discretizes the given alpha_t_bar function,
153
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
154
+ :param num_diffusion_timesteps: the number of betas to produce.
155
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
156
+ produces the cumulative product of (1-beta) up to that
157
+ part of the diffusion process.
158
+ :param max_beta: the maximum beta to use; use values lower than 1 to
159
+ prevent singularities.
160
+ """
161
+ betas = []
162
+ for i in range(num_diffusion_timesteps):
163
+ t1 = i / num_diffusion_timesteps
164
+ t2 = (i + 1) / num_diffusion_timesteps
165
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
166
+ return np.array(betas)
167
+
168
+
169
+ def extract_into_tensor(a, t, x_shape):
170
+ b, *_ = t.shape
171
+ out = a.gather(-1, t)
172
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
173
+
174
+
175
+ def checkpoint(func, inputs, params, flag):
176
+ """
177
+ Evaluate a function without caching intermediate activations, allowing for
178
+ reduced memory at the expense of extra compute in the backward pass.
179
+ :param func: the function to evaluate.
180
+ :param inputs: the argument sequence to pass to `func`.
181
+ :param params: a sequence of parameters `func` depends on but does not
182
+ explicitly take as arguments.
183
+ :param flag: if False, disable gradient checkpointing.
184
+ """
185
+ if flag:
186
+ args = tuple(inputs) + tuple(params)
187
+ return CheckpointFunction.apply(func, len(inputs), *args)
188
+ else:
189
+ return func(*inputs)
190
+
191
+
192
+ class CheckpointFunction(torch.autograd.Function):
193
+ @staticmethod
194
+ def forward(ctx, run_function, length, *args):
195
+ ctx.run_function = run_function
196
+ ctx.input_tensors = list(args[:length])
197
+ ctx.input_params = list(args[length:])
198
+ ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
199
+ "dtype": torch.get_autocast_gpu_dtype(),
200
+ "cache_enabled": torch.is_autocast_cache_enabled()}
201
+ with torch.no_grad():
202
+ output_tensors = ctx.run_function(*ctx.input_tensors)
203
+ return output_tensors
204
+
205
+ @staticmethod
206
+ def backward(ctx, *output_grads):
207
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
208
+ with torch.enable_grad(), \
209
+ torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
210
+ # Fixes a bug where the first op in run_function modifies the
211
+ # Tensor storage in place, which is not allowed for detach()'d
212
+ # Tensors.
213
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
214
+ output_tensors = ctx.run_function(*shallow_copies)
215
+ input_grads = torch.autograd.grad(
216
+ output_tensors,
217
+ ctx.input_tensors + ctx.input_params,
218
+ output_grads,
219
+ allow_unused=True,
220
+ )
221
+ del ctx.input_tensors
222
+ del ctx.input_params
223
+ del output_tensors
224
+ return (None, None) + input_grads
225
+
226
+
227
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
228
+ """
229
+ Create sinusoidal timestep embeddings.
230
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
231
+ These may be fractional.
232
+ :param dim: the dimension of the output.
233
+ :param max_period: controls the minimum frequency of the embeddings.
234
+ :return: an [N x dim] Tensor of positional embeddings.
235
+ """
236
+ if not repeat_only:
237
+ half = dim // 2
238
+ freqs = torch.exp(
239
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half
240
+ )
241
+ args = timesteps[:, None].float() * freqs[None]
242
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
243
+ if dim % 2:
244
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
245
+ else:
246
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
247
+ return embedding
248
+
249
+
250
+ def zero_module(module):
251
+ """
252
+ Zero out the parameters of a module and return it.
253
+ """
254
+ for p in module.parameters():
255
+ p.detach().zero_()
256
+ return module
257
+
258
+
259
+ def scale_module(module, scale):
260
+ """
261
+ Scale the parameters of a module and return it.
262
+ """
263
+ for p in module.parameters():
264
+ p.detach().mul_(scale)
265
+ return module
266
+
267
+
268
+ def mean_flat(tensor):
269
+ """
270
+ Take the mean over all non-batch dimensions.
271
+ """
272
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
273
+
274
+
275
+ def avg_pool_nd(dims, *args, **kwargs):
276
+ """
277
+ Create a 1D, 2D, or 3D average pooling module.
278
+ """
279
+ if dims == 1:
280
+ return nn.AvgPool1d(*args, **kwargs)
281
+ elif dims == 2:
282
+ return nn.AvgPool2d(*args, **kwargs)
283
+ elif dims == 3:
284
+ return nn.AvgPool3d(*args, **kwargs)
285
+ raise ValueError(f"unsupported dimensions: {dims}")
286
+
287
+
288
+ class HybridConditioner(nn.Module):
289
+
290
+ def __init__(self, c_concat_config, c_crossattn_config):
291
+ super().__init__()
292
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
293
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
294
+
295
+ def forward(self, c_concat, c_crossattn):
296
+ c_concat = self.concat_conditioner(c_concat)
297
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
298
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
299
+
300
+
301
+ def noise_like(shape, device, repeat=False):
302
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
303
+ noise = lambda: torch.randn(shape, device=device)
304
+ return repeat_noise() if repeat else noise()
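Not part of the uploaded file — a rough standalone sketch exercising a few of the helpers above; the shapes and sizes in the comments are what the functions return as written.

import torch

t = torch.tensor([0, 250, 999])
emb = timestep_embedding(t, dim=320)        # -> float tensor of shape [3, 320]

betas = make_beta_schedule("linear", 1000)  # -> numpy array of 1000 betas
ddim_steps = make_ddim_timesteps("uniform", 50, 1000, verbose=False)
sigmas, alphas, alphas_prev = make_ddim_sampling_parameters(
    alphacums=(1.0 - betas).cumprod(),      # cumulative products of (1 - beta)
    ddim_timesteps=ddim_steps,
    eta=0.0,
    verbose=False,
)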
ldm_patched/ldm/modules/distributions/__init__.py ADDED
File without changes
ldm_patched/ldm/modules/distributions/distributions.py ADDED
@@ -0,0 +1,92 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ class AbstractDistribution:
6
+ def sample(self):
7
+ raise NotImplementedError()
8
+
9
+ def mode(self):
10
+ raise NotImplementedError()
11
+
12
+
13
+ class DiracDistribution(AbstractDistribution):
14
+ def __init__(self, value):
15
+ self.value = value
16
+
17
+ def sample(self):
18
+ return self.value
19
+
20
+ def mode(self):
21
+ return self.value
22
+
23
+
24
+ class DiagonalGaussianDistribution(object):
25
+ def __init__(self, parameters, deterministic=False):
26
+ self.parameters = parameters
27
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29
+ self.deterministic = deterministic
30
+ self.std = torch.exp(0.5 * self.logvar)
31
+ self.var = torch.exp(self.logvar)
32
+ if self.deterministic:
33
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34
+
35
+ def sample(self):
36
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37
+ return x
38
+
39
+ def kl(self, other=None):
40
+ if self.deterministic:
41
+ return torch.Tensor([0.])
42
+ else:
43
+ if other is None:
44
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
45
+ + self.var - 1.0 - self.logvar,
46
+ dim=[1, 2, 3])
47
+ else:
48
+ return 0.5 * torch.sum(
49
+ torch.pow(self.mean - other.mean, 2) / other.var
50
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
51
+ dim=[1, 2, 3])
52
+
53
+ def nll(self, sample, dims=[1,2,3]):
54
+ if self.deterministic:
55
+ return torch.Tensor([0.])
56
+ logtwopi = np.log(2.0 * np.pi)
57
+ return 0.5 * torch.sum(
58
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59
+ dim=dims)
60
+
61
+ def mode(self):
62
+ return self.mean
63
+
64
+
65
+ def normal_kl(mean1, logvar1, mean2, logvar2):
66
+ """
67
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68
+ Compute the KL divergence between two gaussians.
69
+ Shapes are automatically broadcasted, so batches can be compared to
70
+ scalars, among other use cases.
71
+ """
72
+ tensor = None
73
+ for obj in (mean1, logvar1, mean2, logvar2):
74
+ if isinstance(obj, torch.Tensor):
75
+ tensor = obj
76
+ break
77
+ assert tensor is not None, "at least one argument must be a Tensor"
78
+
79
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
80
+ # Tensors, but it does not work for torch.exp().
81
+ logvar1, logvar2 = [
82
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83
+ for x in (logvar1, logvar2)
84
+ ]
85
+
86
+ return 0.5 * (
87
+ -1.0
88
+ + logvar2
89
+ - logvar1
90
+ + torch.exp(logvar1 - logvar2)
91
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92
+ )
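Not part of the uploaded file — a hedged sketch of how DiagonalGaussianDistribution is typically driven by a VAE-style encoder output (2 * latent_channels stacked along dim=1); the shapes are illustrative only.

import torch

params = torch.randn(1, 8, 16, 16)                # mean and logvar stacked on dim 1
posterior = DiagonalGaussianDistribution(params)
z = posterior.sample()                            # -> [1, 4, 16, 16], mean + std * noise
z_det = posterior.mode()                          # just the mean, no noise
kl = posterior.kl()                               # -> shape [1], KL against a standard normal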
ldm_patched/ldm/modules/ema.py ADDED
@@ -0,0 +1,80 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class LitEma(nn.Module):
6
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
7
+ super().__init__()
8
+ if decay < 0.0 or decay > 1.0:
9
+ raise ValueError('Decay must be between 0 and 1')
10
+
11
+ self.m_name2s_name = {}
12
+ self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
14
+ else torch.tensor(-1, dtype=torch.int))
15
+
16
+ for name, p in model.named_parameters():
17
+ if p.requires_grad:
18
+ # remove as '.'-character is not allowed in buffers
19
+ s_name = name.replace('.', '')
20
+ self.m_name2s_name.update({name: s_name})
21
+ self.register_buffer(s_name, p.clone().detach().data)
22
+
23
+ self.collected_params = []
24
+
25
+ def reset_num_updates(self):
26
+ del self.num_updates
27
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
28
+
29
+ def forward(self, model):
30
+ decay = self.decay
31
+
32
+ if self.num_updates >= 0:
33
+ self.num_updates += 1
34
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
35
+
36
+ one_minus_decay = 1.0 - decay
37
+
38
+ with torch.no_grad():
39
+ m_param = dict(model.named_parameters())
40
+ shadow_params = dict(self.named_buffers())
41
+
42
+ for key in m_param:
43
+ if m_param[key].requires_grad:
44
+ sname = self.m_name2s_name[key]
45
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
46
+ shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
47
+ else:
48
+ assert not key in self.m_name2s_name
49
+
50
+ def copy_to(self, model):
51
+ m_param = dict(model.named_parameters())
52
+ shadow_params = dict(self.named_buffers())
53
+ for key in m_param:
54
+ if m_param[key].requires_grad:
55
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
56
+ else:
57
+ assert not key in self.m_name2s_name
58
+
59
+ def store(self, parameters):
60
+ """
61
+ Save the current parameters for restoring later.
62
+ Args:
63
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
64
+ temporarily stored.
65
+ """
66
+ self.collected_params = [param.clone() for param in parameters]
67
+
68
+ def restore(self, parameters):
69
+ """
70
+ Restore the parameters stored with the `store` method.
71
+ Useful to validate the model with EMA parameters without affecting the
72
+ original optimization process. Store the parameters before the
73
+ `copy_to` method. After validation (or model saving), use this to
74
+ restore the former parameters.
75
+ Args:
76
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
77
+ updated with the stored parameters.
78
+ """
79
+ for c_param, param in zip(self.collected_params, parameters):
80
+ param.data.copy_(c_param.data)
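Not part of the uploaded file — a hedged training-loop sketch for LitEma; the nn.Linear model is a stand-in, and the optimizer step and validation code are elided as comments.

import torch.nn as nn

model = nn.Linear(8, 8)
ema = LitEma(model, decay=0.999)

# after each optimizer step:
ema(model)                        # update the shadow (EMA) copies of the weights

# to evaluate with EMA weights and then resume training:
ema.store(model.parameters())     # stash the live weights
ema.copy_to(model)                # load the EMA weights into the model
# ... run validation here ...
ema.restore(model.parameters())   # put the live weights back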
ldm_patched/ldm/modules/encoders/__init__.py ADDED
File without changes