Hirokusa committed
Commit 44c5599
Parent: 2ebf309

Upload sd_hijack.py

Files changed (1)
  1. sd_hijack.py +195 -0
sd_hijack.py ADDED
@@ -0,0 +1,195 @@
import math
import os
import sys
import traceback
import torch
import numpy as np
from torch import einsum
from torch.nn.functional import silu

import modules.textual_inversion.textual_inversion
from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
from modules.hypernetworks import hypernetwork
from modules.shared import opts, device, cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip

from modules.sd_hijack_optimizations import invokeAI_mps_available

import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import ldm.modules.encoders.modules

attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward

# new memory efficient cross attention blocks do not support hypernets and we already
# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention

# silence new console spam from SD2
ldm.modules.attention.print = lambda *args: None
ldm.modules.diffusionmodules.model.print = lambda *args: None

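# Picks a cross attention implementation based on command line options and available
# hardware, and monkeypatches it into ldm's CrossAttention / AttnBlock classes.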
def apply_optimizations():
    undo_optimizations()

    ldm.modules.diffusionmodules.model.nonlinearity = silu

    if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
        print("Applying xformers cross attention optimization.")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
    elif cmd_opts.opt_split_attention_v1:
        print("Applying v1 cross attention optimization.")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
        if not invokeAI_mps_available and shared.device.type == 'mps':
            print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
            print("Applying v1 cross attention optimization.")
            ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
        else:
            print("Applying cross attention optimization (InvokeAI).")
            ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
        print("Applying cross attention optimization (Doggettx).")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward

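# Reverts the monkeypatched functions: CrossAttention goes back to the hypernetwork
# module's stock forward, nonlinearity and AttnBlock.forward to the originals saved above.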
def undo_optimizations():
    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
    ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
    ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward

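# Replaces transformer/UNet block forwards with the gradient-checkpointing variants
# from sd_hijack_checkpoint, which trade compute for lower memory use during training.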
def fix_checkpoint():
    ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward
    ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
    ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward

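# Wraps a loaded Stable Diffusion model: swaps its cond_stage_model for a version that
# understands webui prompt syntax and textual inversion embeddings, and applies the
# attention optimizations defined above.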
class StableDiffusionModelHijack:
    fixes = None
    comments = []
    layers = None
    circular_enabled = False
    clip = None

    embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)

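    # Wraps the model's text encoder (CLIP or OpenCLIP) and its token embedding layer,
    # then applies optimizations and collects a flat list of the model's submodules.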
    def hijack(self, m):
        if type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
            m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
            m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
            m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        self.clip = m.cond_stage_model

        apply_optimizations()
        fix_checkpoint()

        def flatten(el):
            flattened = [flatten(children) for children in el.children()]
            res = [el]
            for c in flattened:
                res += c
            return res

        self.layers = flatten(m)

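    # Reverses hijack(): unwraps the text encoder and token embedding and clears state.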
    def undo_hijack(self, m):
        if type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
            m.cond_stage_model = m.cond_stage_model.wrapped

            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
                model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
        elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
            m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
            m.cond_stage_model = m.cond_stage_model.wrapped

        self.apply_circular(False)
        self.layers = None
        self.clip = None

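    # Switches every Conv2d in the model between 'circular' and 'zeros' padding,
    # which makes generated images tile seamlessly when enabled.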
    def apply_circular(self, enable):
        if self.circular_enabled == enable:
            return

        self.circular_enabled = enable

        for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
            layer.padding_mode = 'circular' if enable else 'zeros'

    def clear_comments(self):
        self.comments = []

    def tokenize(self, text):
        _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
        return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)

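# nn.Module wrapper around the text encoder's token embedding layer: after the normal
# embedding lookup it splices textual inversion vectors, recorded as (offset, embedding)
# pairs in embeddings.fixes, into the returned tensor at their token offsets.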
class EmbeddingsWithFixes(torch.nn.Module):
    def __init__(self, wrapped, embeddings):
        super().__init__()
        self.wrapped = wrapped
        self.embeddings = embeddings

    def forward(self, input_ids):
        batch_fixes = self.embeddings.fixes
        self.embeddings.fixes = None

        inputs_embeds = self.wrapped(input_ids)

        if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
            return inputs_embeds

        vecs = []
        for fixes, tensor in zip(batch_fixes, inputs_embeds):
            for offset, embedding in fixes:
                emb = embedding.vec
                emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
                tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])

            vecs.append(tensor)

        return torch.stack(vecs)

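# Monkeypatches torch.nn.Conv2d.__init__ so every Conv2d created afterwards is
# constructed with circular padding.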
def add_circular_option_to_conv_2d():
    conv2d_constructor = torch.nn.Conv2d.__init__

    def conv2d_constructor_circular(self, *args, **kwargs):
        return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)

    torch.nn.Conv2d.__init__ = conv2d_constructor_circular

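# Shared singleton instance. The rest of the webui is expected to pass the loaded
# checkpoint through it, roughly:
#     model_hijack.hijack(sd_model)        # after loading a checkpoint
#     model_hijack.undo_hijack(sd_model)   # before swapping it out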
model_hijack = StableDiffusionModelHijack()

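# Replacement for DDIMSampler/PLMSSampler.register_buffer: moves tensors to the active
# device, casting to float32 on MPS because float64 buffers are not supported there.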
def register_buffer(self, name, attr):
    """
    Fix register buffer bug for Mac OS.
    """

    if type(attr) == torch.Tensor:
        if attr.device != devices.device:

            if devices.has_mps():
                attr = attr.to(device="mps", dtype=torch.float32)
            else:
                attr = attr.to(devices.device)

    setattr(self, name, attr)


ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer