Hirokusa committed on
Commit
efb1e2a
Parent: 6afc7c5

Upload sd_hijack.py

Files changed (1)
  1. sd_hijack.py +200 -0
sd_hijack.py ADDED
@@ -0,0 +1,200 @@
import math
import os
import sys
import traceback
import torch
import numpy as np
from torch import einsum
from torch.nn.functional import silu

import modules.textual_inversion.textual_inversion
from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
from modules.hypernetworks import hypernetwork
from modules.shared import opts, device, cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip

from modules.sd_hijack_optimizations import invokeAI_mps_available

import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.modules.diffusionmodules.openaimodel  # needed by fix_checkpoint() below
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import ldm.modules.encoders.modules

attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward

# new memory efficient cross attention blocks do not support hypernets and we already
# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention

# silence new console spam from SD2
ldm.modules.attention.print = lambda *args: None
ldm.modules.diffusionmodules.model.print = lambda *args: None

def apply_optimizations():
    undo_optimizations()

    ldm.modules.diffusionmodules.model.nonlinearity = silu

    if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
        print("Applying xformers cross attention optimization.")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
    elif cmd_opts.opt_split_attention_v1:
        print("Applying v1 cross attention optimization.")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
        if not invokeAI_mps_available and shared.device.type == 'mps':
            print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
            print("Applying v1 cross attention optimization.")
            ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
        else:
            print("Applying cross attention optimization (InvokeAI).")
            ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
        print("Applying cross attention optimization (Doggettx).")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
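
# The branches above are tried in order: xformers (opt-in, CUDA devices with
# compute capability 6.0 through 9.0), the v1 split-attention fallback, InvokeAI's
# split attention (the default when CUDA is unavailable), then Doggettx's split
# attention (the default on CUDA). If no branch matches, the stock forward
# restored by undo_optimizations() stays in place.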


def undo_optimizations():
    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
    ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
    ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward


def fix_checkpoint():
    ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward
    ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
    ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward


class StableDiffusionModelHijack:
    fixes = None
    comments = []
    layers = None
    circular_enabled = False
    clip = None

    embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)

    def hijack(self, m, use_improved_clip=True):
        if type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
            if use_improved_clip:
                # swap in the text encoder and tokenizer of the CLIP ViT-L/14@336px
                # checkpoint in place of the ViT-L/14 encoder SD v1 ships with; the
                # weights differ, so outputs will differ from stock SD v1
                from transformers import CLIPTextModel, CLIPTokenizer
                device = "cuda" if torch.cuda.is_available() else "cpu"
                m.cond_stage_model.transformer = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14-336").to(device)
                m.cond_stage_model.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14-336")
            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
            m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
            m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
            m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        self.clip = m.cond_stage_model

        apply_optimizations()
        fix_checkpoint()

        def flatten(el):
            flattened = [flatten(children) for children in el.children()]
            res = [el]
            for c in flattened:
                res += c
            return res

        self.layers = flatten(m)

    def undo_hijack(self, m):
        if type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
            m.cond_stage_model = m.cond_stage_model.wrapped

            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
                model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
        elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:
            m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped
            m.cond_stage_model = m.cond_stage_model.wrapped

        self.apply_circular(False)
        self.layers = None
        self.clip = None

    def apply_circular(self, enable):
        if self.circular_enabled == enable:
            return

        self.circular_enabled = enable

        for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
            layer.padding_mode = 'circular' if enable else 'zeros'

    def clear_comments(self):
        self.comments = []

    def tokenize(self, text):
        _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
        return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)


class EmbeddingsWithFixes(torch.nn.Module):
    def __init__(self, wrapped, embeddings):
        super().__init__()
        self.wrapped = wrapped
        self.embeddings = embeddings

    def forward(self, input_ids):
        batch_fixes = self.embeddings.fixes
        self.embeddings.fixes = None

        inputs_embeds = self.wrapped(input_ids)

        if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
            return inputs_embeds

        vecs = []
        for fixes, tensor in zip(batch_fixes, inputs_embeds):
            for offset, embedding in fixes:
                emb = embedding.vec
                emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
                tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])

            vecs.append(tensor)

        return torch.stack(vecs)
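
# How EmbeddingsWithFixes.forward splices embeddings in: for a fix
# (offset, embedding) whose embedding.vec has shape [n, dim], the n vectors
# overwrite the n token embeddings that follow position `offset`, clipped via
# emb_len so the sequence length never changes. For example, with a 77-token
# sequence, offset=2 and a 3-vector embedding:
#     tensor = cat([tensor[0:3], emb[0:3], tensor[6:77]])  # still 77 rows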


def add_circular_option_to_conv_2d():
    conv2d_constructor = torch.nn.Conv2d.__init__

    def conv2d_constructor_circular(self, *args, **kwargs):
        return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)

    torch.nn.Conv2d.__init__ = conv2d_constructor_circular
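
# Caveat: the patch is global and one-way. Every torch.nn.Conv2d constructed
# after the call is forced to circular padding, and a caller that passes its own
# padding_mode= raises a TypeError (duplicate keyword argument). Nothing in this
# file restores the original constructor once the patch is applied.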


model_hijack = StableDiffusionModelHijack()


def register_buffer(self, name, attr):
    """
    Fix register buffer bug for Mac OS.
    """

    if type(attr) == torch.Tensor:
        if attr.device != devices.device:
            if devices.has_mps():
                # MPS has no float64 support, so buffers are moved over as float32
                attr = attr.to(device="mps", dtype=torch.float32)
            else:
                attr = attr.to(devices.device)

    setattr(self, name, attr)


ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
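
A minimal usage sketch, assuming an already-loaded LatentDiffusion model sd_model (in the webui itself, modules/sd_models.py makes the equivalent calls when a checkpoint is loaded):

    from modules import sd_hijack

    sd_hijack.model_hijack.hijack(sd_model)        # wrap CLIP, patch attention, collect layers
    tokens, token_count, target = sd_hijack.model_hijack.tokenize("a photo of a cat")
    sd_hijack.model_hijack.undo_hijack(sd_model)   # restore the original cond_stage_model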