nyanko7 commited on
Commit
1e66485
1 Parent(s): 0aa5cc5

feat: requirements, cleanup

Browse files
Files changed (6) hide show
  1. Dockerfile +22 -0
  2. README.md +6 -6
  3. app.py +77 -28
  4. modules/lora.py +181 -0
  5. modules/model.py +0 -144
  6. requirements.txt +0 -8
Dockerfile ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dockerfile Public T4
2
+
3
+ FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04
4
+ ENV DEBIAN_FRONTEND noninteractive
5
+
6
+ WORKDIR /content
7
+
8
+ RUN apt-get update -y && apt-get upgrade -y && apt-get install -y libgl1 libglib2.0-0 wget git git-lfs python3-pip python-is-python3 && pip3 install --upgrade pip
9
+
10
+ RUN pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchsde --extra-index-url https://download.pytorch.org/whl/cu113
11
+ RUN pip install https://github.com/camenduru/stable-diffusion-webui-colab/releases/download/0.0.16/xformers-0.0.16+814314d.d20230118-cp310-cp310-linux_x86_64.whl
12
+ RUN pip install --pre triton
13
+ RUN pip install numexpr einops diffusers transformers k_diffusion safetensors gradio
14
+
15
+ ADD . .
16
+ RUN adduser --disabled-password --gecos '' user
17
+ RUN chown -R user:user /content
18
+ RUN chmod -R 777 /content
19
+ USER user
20
+
21
+ EXPOSE 7860
22
+ CMD python /content/app.py
README.md CHANGED
@@ -1,13 +1,13 @@
1
  ---
2
  title: Sd Diffusers Webui
3
- emoji: 👀
4
- colorFrom: red
5
- colorTo: blue
6
- sdk: gradio
7
- sdk_version: 3.16.2
8
- app_file: app.py
9
  pinned: false
10
  license: openrail
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
  title: Sd Diffusers Webui
3
+ emoji: 🐳
4
+ colorFrom: purple
5
+ colorTo: gray
6
+ sdk: docker
7
+ sdk_version: 3.9
 
8
  pinned: false
9
  license: openrail
10
+ app_port: 7860
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -4,6 +4,7 @@ import time
4
  import gradio as gr
5
  import numpy as np
6
  import torch
 
7
 
8
  from gradio import inputs
9
  from diffusers import (
@@ -14,7 +15,6 @@ from diffusers import (
14
  from modules.model import (
15
  CrossAttnProcessor,
16
  StableDiffusionPipeline,
17
- load_lora_attn_procs,
18
  )
19
  from torchvision import transforms
20
  from transformers import CLIPTokenizer, CLIPTextModel
@@ -22,16 +22,17 @@ from PIL import Image
22
  from pathlib import Path
23
  from safetensors.torch import load_file
24
  import modules.safe as _
 
25
 
26
  models = [
27
- ("AbyssOrangeMix2", "Korakoe/AbyssOrangeMix2-HF"),
28
- ("Basil Mix", "nuigurumi/basil_mix"),
29
- ("Pastal Mix", "andite/pastel-mix"),
30
- ("ACertainModel", "JosephusCheung/ACertainModel"),
 
31
  ]
32
 
33
- base_name, base_model = models[0]
34
- clip_skip = 2
35
 
36
  samplers_k_diffusion = [
37
  ("Euler a", "sample_euler_ancestral", {}),
@@ -103,6 +104,10 @@ unet_cache = {
103
  base_name: unet
104
  }
105
 
 
 
 
 
106
  def get_model(name):
107
  keys = [k[0] for k in models]
108
  if name not in unet_cache:
@@ -114,11 +119,21 @@ def get_model(name):
114
  subfolder="unet",
115
  torch_dtype=torch.float16,
116
  )
 
 
 
117
  unet_cache[name] = unet
 
118
 
119
  g_unet = unet_cache[name]
120
- g_unet.set_attn_processor(None)
121
- return g_unet
 
 
 
 
 
 
122
 
123
  def error_str(error, title="Error"):
124
  return (
@@ -129,18 +144,46 @@ def error_str(error, title="Error"):
129
  )
130
 
131
 
132
- te_base_weight = text_encoder.get_input_embeddings().weight.data.detach().clone()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
 
135
  def restore_all():
136
  global te_base_weight, tokenizer
137
- text_encoder.get_input_embeddings().weight.data = te_base_weight
 
 
 
138
  tokenizer = CLIPTokenizer.from_pretrained(
139
  base_model,
140
  subfolder="tokenizer",
141
  torch_dtype=torch.float16,
142
  )
143
 
 
 
 
 
 
 
 
 
144
 
145
  def inference(
146
  prompt,
@@ -167,13 +210,15 @@ def inference(
167
  global pipe, unet, tokenizer, text_encoder
168
  if seed is None or seed == 0:
169
  seed = random.randint(0, 2147483647)
 
 
 
170
  generator = torch.Generator("cuda").manual_seed(int(seed))
171
 
172
- local_unet = get_model(model)
173
  if lora_state is not None and lora_state != "":
174
- load_lora_attn_procs(lora_state, local_unet, lora_scale)
175
- else:
176
- local_unet.set_attn_processor(CrossAttnProcessor())
177
 
178
  pipe.setup_unet(local_unet)
179
  sampler_name, sampler_opt = None, None
@@ -182,23 +227,23 @@ def inference(
182
  sampler_name, sampler_opt = funcname, options
183
 
184
  if embs is not None and len(embs) > 0:
185
- delta_weight = []
186
  for name, file in embs.items():
187
  if str(file).endswith(".pt"):
188
  loaded_learned_embeds = torch.load(file, map_location="cpu")
189
  else:
190
  loaded_learned_embeds = load_file(file, device="cpu")
191
  loaded_learned_embeds = loaded_learned_embeds["string_to_param"]["*"]
192
- added_length = tokenizer.add_tokens(name)
193
 
194
- assert added_length == loaded_learned_embeds.shape[0]
195
- delta_weight.append(loaded_learned_embeds)
 
 
196
 
197
- delta_weight = torch.cat(delta_weight, dim=0)
198
- text_encoder.resize_token_embeddings(len(tokenizer))
199
- text_encoder.get_input_embeddings().weight.data[
200
- -delta_weight.shape[0] :
201
- ] = delta_weight
202
 
203
  config = {
204
  "negative_prompt": neg_prompt,
@@ -234,6 +279,10 @@ def inference(
234
  # restore
235
  if embs is not None and len(embs) > 0:
236
  restore_all()
 
 
 
 
237
  return gr.Image.update(result[0][0], label=f"Initial Seed: {seed}")
238
 
239
 
@@ -513,7 +562,7 @@ with gr.Blocks(css=css) as demo:
513
  label="Guidance scale", value=7.5, maximum=15
514
  )
515
  steps = gr.Slider(
516
- label="Steps", value=25, minimum=2, maximum=75, step=1
517
  )
518
 
519
  with gr.Row():
@@ -704,7 +753,7 @@ with gr.Blocks(css=css) as demo:
704
  step=0.01,
705
  value=0.5,
706
  )
707
-
708
 
709
  sk_update.click(
710
  detect_text,
@@ -739,7 +788,7 @@ with gr.Blocks(css=css) as demo:
739
  source="upload",
740
  shape=(512, 512),
741
  )
742
-
743
  mask_outsides2 = gr.Checkbox(
744
  label="Mask other areas",
745
  value=False
@@ -803,4 +852,4 @@ with gr.Blocks(css=css) as demo:
803
 
804
  print(f"Space built in {time.time() - start_time:.2f} seconds")
805
  # demo.launch(share=True)
806
- demo.launch()
4
  import gradio as gr
5
  import numpy as np
6
  import torch
7
+ import math
8
 
9
  from gradio import inputs
10
  from diffusers import (
15
  from modules.model import (
16
  CrossAttnProcessor,
17
  StableDiffusionPipeline,
 
18
  )
19
  from torchvision import transforms
20
  from transformers import CLIPTokenizer, CLIPTextModel
22
  from pathlib import Path
23
  from safetensors.torch import load_file
24
  import modules.safe as _
25
+ from modules.lora import LoRANetwork
26
 
27
  models = [
28
+ # format: name, model_path, clip_skip
29
+ ("AbyssOrangeMix2", "Korakoe/AbyssOrangeMix2-HF", 2),
30
+ ("Basil Mix", "nuigurumi/basil_mix", 2),
31
+ ("Pastal Mix", "andite/pastel-mix", 2),
32
+ ("ACertainModel", "JosephusCheung/ACertainModel", 2),
33
  ]
34
 
35
+ base_name, base_model, clip_skip = models[0]
 
36
 
37
  samplers_k_diffusion = [
38
  ("Euler a", "sample_euler_ancestral", {}),
104
  base_name: unet
105
  }
106
 
107
+ lora_cache = {
108
+ base_name: LoRANetwork(text_encoder, unet)
109
+ }
110
+
111
  def get_model(name):
112
  keys = [k[0] for k in models]
113
  if name not in unet_cache:
119
  subfolder="unet",
120
  torch_dtype=torch.float16,
121
  )
122
+ if torch.cuda.is_available():
123
+ unet.to("cuda")
124
+
125
  unet_cache[name] = unet
126
+ lora_cache[name] = LoRANetwork(lora_cache[base_name].text_encoder_loras, unet)
127
 
128
  g_unet = unet_cache[name]
129
+ g_lora = lora_cache[name]
130
+ g_unet.set_attn_processor(CrossAttnProcessor())
131
+ g_lora.reset()
132
+ return g_unet, g_lora
133
+
134
+ # precache on huggingface
135
+ # for model in get_model_list():
136
+ # get_model(model[0])
137
 
138
  def error_str(error, title="Error"):
139
  return (
144
  )
145
 
146
 
147
+ te_base_weight_length = text_encoder.get_input_embeddings().weight.data.shape[0]
148
+ original_prepare_for_tokenization = tokenizer.prepare_for_tokenization
149
+
150
+ def make_token_names(embs):
151
+ all_tokens = []
152
+ for name, vec in embs.items():
153
+ tokens = [f'emb-{name}-{i}' for i in range(len(vec))]
154
+ all_tokens.append(tokens)
155
+ return all_tokens
156
+
157
+ def setup_tokenizer(embs):
158
+ reg_match = [re.compile(fr"(?:^|(?<=\s|,)){k}(?=,|\s|$)") for k in embs.keys()]
159
+ clip_keywords = [' '.join(s) for s in make_token_names(embs)]
160
+
161
+ def parse_prompt(prompt: str):
162
+ for m, v in zip(reg_match, clip_keywords):
163
+ prompt = m.sub(v, prompt)
164
+ return prompt
165
 
166
 
167
  def restore_all():
168
  global te_base_weight, tokenizer
169
+ tokenizer.prepare_for_tokenization = original_prepare_for_tokenization
170
+
171
+ embeddings = text_encoder.get_input_embeddings()
172
+ text_encoder.get_input_embeddings().weight.data = embeddings.weight.data[:te_base_weight_length]
173
  tokenizer = CLIPTokenizer.from_pretrained(
174
  base_model,
175
  subfolder="tokenizer",
176
  torch_dtype=torch.float16,
177
  )
178
 
179
+ def convert_size(size_bytes):
180
+ if size_bytes == 0:
181
+ return "0B"
182
+ size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
183
+ i = int(math.floor(math.log(size_bytes, 1024)))
184
+ p = math.pow(1024, i)
185
+ s = round(size_bytes / p, 2)
186
+ return "%s %s" % (s, size_name[i])
187
 
188
  def inference(
189
  prompt,
210
  global pipe, unet, tokenizer, text_encoder
211
  if seed is None or seed == 0:
212
  seed = random.randint(0, 2147483647)
213
+
214
+ start_time = time.time()
215
+ restore_all()
216
  generator = torch.Generator("cuda").manual_seed(int(seed))
217
 
218
+ local_unet, local_lora = get_model(model)
219
  if lora_state is not None and lora_state != "":
220
+ local_lora.load(lora_state, lora_scale)
221
+ local_lora.to(local_unet.device, dtype=local_unet.dtype)
 
222
 
223
  pipe.setup_unet(local_unet)
224
  sampler_name, sampler_opt = None, None
227
  sampler_name, sampler_opt = funcname, options
228
 
229
  if embs is not None and len(embs) > 0:
230
+ ti_embs = {}
231
  for name, file in embs.items():
232
  if str(file).endswith(".pt"):
233
  loaded_learned_embeds = torch.load(file, map_location="cpu")
234
  else:
235
  loaded_learned_embeds = load_file(file, device="cpu")
236
  loaded_learned_embeds = loaded_learned_embeds["string_to_param"]["*"]
237
+ ti_embs[name] = loaded_learned_embeds
238
 
239
+ if len(ti_embs) > 0:
240
+ tokens = setup_tokenizer(ti_embs)
241
+ added_tokens = tokenizer.add_tokens(tokens)
242
+ delta_weight = torch.cat([val for val in ti_embs.values()], dim=0)
243
 
244
+ assert added_tokens == delta_weight.shape[0]
245
+ text_encoder.resize_token_embeddings(len(tokenizer))
246
+ text_encoder.get_input_embeddings().weight.data[-delta_weight.shape[0]:] = delta_weight
 
 
247
 
248
  config = {
249
  "negative_prompt": neg_prompt,
279
  # restore
280
  if embs is not None and len(embs) > 0:
281
  restore_all()
282
+
283
+ end_time = time.time()
284
+ vram_free, vram_total = torch.cuda.mem_get_info()
285
+ print(f"done: res={width}x{height}, step={steps}, time={round(end_time-start_time, 2)}s, vram_alloc={convert_size(vram_total-vram_free)}/{convert_size(vram_total)}")
286
  return gr.Image.update(result[0][0], label=f"Initial Seed: {seed}")
287
 
288
 
562
  label="Guidance scale", value=7.5, maximum=15
563
  )
564
  steps = gr.Slider(
565
+ label="Steps", value=25, minimum=2, maximum=50, step=1
566
  )
567
 
568
  with gr.Row():
753
  step=0.01,
754
  value=0.5,
755
  )
756
+
757
 
758
  sk_update.click(
759
  detect_text,
788
  source="upload",
789
  shape=(512, 512),
790
  )
791
+
792
  mask_outsides2 = gr.Checkbox(
793
  label="Mask other areas",
794
  value=False
852
 
853
  print(f"Space built in {time.time() - start_time:.2f} seconds")
854
  # demo.launch(share=True)
855
+ demo.launch(enable_queue=True, server_name="0.0.0.0", server_port=7860)
modules/lora.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LoRA network module
2
+ # reference:
3
+ # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
4
+ # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
5
+ # https://github.com/bmaltais/kohya_ss/blob/master/networks/lora.py#L48
6
+
7
+ import math
8
+ import os
9
+ import torch
10
+ import modules.safe as _
11
+ from safetensors.torch import load_file
12
+
13
+
14
+ class LoRAModule(torch.nn.Module):
15
+ """
16
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
17
+ """
18
+
19
+ def __init__(
20
+ self,
21
+ lora_name,
22
+ org_module: torch.nn.Module,
23
+ multiplier=1.0,
24
+ lora_dim=4,
25
+ alpha=1,
26
+ ):
27
+ """if alpha == 0 or None, alpha is rank (no scaling)."""
28
+ super().__init__()
29
+ self.lora_name = lora_name
30
+ self.lora_dim = lora_dim
31
+
32
+ if org_module.__class__.__name__ == "Conv2d":
33
+ in_dim = org_module.in_channels
34
+ out_dim = org_module.out_channels
35
+ self.lora_down = torch.nn.Conv2d(in_dim, lora_dim, (1, 1), bias=False)
36
+ self.lora_up = torch.nn.Conv2d(lora_dim, out_dim, (1, 1), bias=False)
37
+ else:
38
+ in_dim = org_module.in_features
39
+ out_dim = org_module.out_features
40
+ self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
41
+ self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
42
+
43
+ if type(alpha) == torch.Tensor:
44
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
45
+
46
+ alpha = lora_dim if alpha is None or alpha == 0 else alpha
47
+ self.scale = alpha / self.lora_dim
48
+ self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
49
+
50
+ # same as microsoft's
51
+ torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
52
+ torch.nn.init.zeros_(self.lora_up.weight)
53
+
54
+ self.multiplier = multiplier
55
+ self.org_module = org_module # remove in applying
56
+ self.enable = False
57
+
58
+ def resize(self, rank, alpha):
59
+ self.alpha = torch.tensor(alpha)
60
+ self.scale = alpha / rank
61
+ if self.lora_down.__class__.__name__ == "Conv2d":
62
+ in_dim = self.lora_down.in_channels
63
+ out_dim = self.lora_up.out_channels
64
+ self.lora_down = torch.nn.Conv2d(in_dim, rank, (1, 1), bias=False)
65
+ self.lora_up = torch.nn.Conv2d(rank, out_dim, (1, 1), bias=False)
66
+ else:
67
+ in_dim = self.lora_down.in_features
68
+ out_dim = self.lora_up.out_features
69
+ self.lora_down = torch.nn.Linear(in_dim, rank, bias=False)
70
+ self.lora_up = torch.nn.Linear(rank, out_dim, bias=False)
71
+
72
+ def apply(self):
73
+ if hasattr(self, "org_module"):
74
+ self.org_forward = self.org_module.forward
75
+ self.org_module.forward = self.forward
76
+ del self.org_module
77
+
78
+ def forward(self, x):
79
+ if self.enable:
80
+ return (
81
+ self.org_forward(x)
82
+ + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
83
+ )
84
+ return self.org_forward(x)
85
+
86
+
87
+ class LoRANetwork(torch.nn.Module):
88
+ UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
89
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
90
+ LORA_PREFIX_UNET = "lora_unet"
91
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
92
+
93
+ def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
94
+ super().__init__()
95
+ self.multiplier = multiplier
96
+ self.lora_dim = lora_dim
97
+ self.alpha = alpha
98
+
99
+ # create module instances
100
+ def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules):
101
+ loras = []
102
+ for name, module in root_module.named_modules():
103
+ if module.__class__.__name__ in target_replace_modules:
104
+ for child_name, child_module in module.named_modules():
105
+ if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
106
+ lora_name = prefix + "." + name + "." + child_name
107
+ lora_name = lora_name.replace(".", "_")
108
+ lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha,)
109
+ loras.append(lora)
110
+ return loras
111
+
112
+ if isinstance(text_encoder, list):
113
+ self.text_encoder_loras = text_encoder
114
+ else:
115
+ self.text_encoder_loras = create_modules(LoRANetwork.LORA_PREFIX_TEXT_ENCODER, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
116
+ print(f"Create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
117
+
118
+ self.unet_loras = create_modules(LoRANetwork.LORA_PREFIX_UNET, unet, LoRANetwork.UNET_TARGET_REPLACE_MODULE)
119
+ print(f"Create LoRA for U-Net: {len(self.unet_loras)} modules.")
120
+
121
+ self.weights_sd = None
122
+
123
+ # assertion
124
+ names = set()
125
+ for lora in self.text_encoder_loras + self.unet_loras:
126
+ assert (lora.lora_name not in names), f"duplicated lora name: {lora.lora_name}"
127
+ names.add(lora.lora_name)
128
+
129
+ lora.apply()
130
+ self.add_module(lora.lora_name, lora)
131
+
132
+ def reset(self):
133
+ for lora in self.text_encoder_loras + self.unet_loras:
134
+ lora.enable = False
135
+
136
+ def load(self, file, scale):
137
+
138
+ weights = None
139
+ if os.path.splitext(file)[1] == ".safetensors":
140
+ weights = load_file(file)
141
+ else:
142
+ weights = torch.load(file, map_location="cpu")
143
+
144
+ if not weights:
145
+ return
146
+
147
+ network_alpha = None
148
+ network_dim = None
149
+ for key, value in weights.items():
150
+ if network_alpha is None and "alpha" in key:
151
+ network_alpha = value
152
+ if network_dim is None and "lora_down" in key and len(value.size()) == 2:
153
+ network_dim = value.size()[0]
154
+
155
+ if network_alpha is None:
156
+ network_alpha = network_dim
157
+
158
+ weights_has_text_encoder = weights_has_unet = False
159
+ weights_to_modify = []
160
+
161
+ for key in weights.keys():
162
+ if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
163
+ weights_has_text_encoder = True
164
+
165
+ if key.startswith(LoRANetwork.LORA_PREFIX_UNET):
166
+ weights_has_unet = True
167
+
168
+ if weights_has_text_encoder:
169
+ weights_to_modify += self.text_encoder_loras
170
+
171
+ if weights_has_unet:
172
+ weights_to_modify += self.unet_loras
173
+
174
+ for lora in self.text_encoder_loras + self.unet_loras:
175
+ lora.resize(network_dim, network_alpha)
176
+ if lora in weights_to_modify:
177
+ lora.enable = True
178
+
179
+ info = self.load_state_dict(weights, False)
180
+ print(f"Weights are loaded. Unexpect keys={info.unexpected_keys}")
181
+
modules/model.py CHANGED
@@ -68,79 +68,6 @@ def get_attention_scores(attn, query, key, attention_mask=None):
68
  return attention_scores
69
 
70
 
71
- def load_lora_attn_procs(model_file, unet, scale=1.0):
72
-
73
- if Path(model_file).suffix == ".pt":
74
- state_dict = torch.load(model_file, map_location="cpu")
75
- else:
76
- state_dict = load_file(model_file, device="cpu")
77
-
78
- if any("lora_unet_down_blocks" in k for k in state_dict.keys()):
79
- # convert ldm format lora
80
- df_lora = {}
81
- attn_numlayer = re.compile(r"_attn(\d)_to_([qkv]|out).lora_")
82
- alpha_numlayer = re.compile(r"_attn(\d)_to_([qkv]|out).alpha")
83
- for k, v in state_dict.items():
84
- if "attn" not in k or "lora_te" in k:
85
- # currently not support: ff, clip-attn
86
- continue
87
- k = k.replace("lora_unet_down_blocks_", "down_blocks.")
88
- k = k.replace("lora_unet_up_blocks_", "up_blocks.")
89
- k = k.replace("lora_unet_mid_block_", "mid_block_")
90
- k = k.replace("_attentions_", ".attentions.")
91
- k = k.replace("_transformer_blocks_", ".transformer_blocks.")
92
- k = k.replace("to_out_0", "to_out")
93
- k = attn_numlayer.sub(r".attn\1.processor.to_\2_lora.", k)
94
- k = alpha_numlayer.sub(r".attn\1.processor.to_\2_lora.alpha", k)
95
- df_lora[k] = v
96
- state_dict = df_lora
97
-
98
- # fill attn processors
99
- attn_processors = {}
100
-
101
- is_lora = all("lora" in k for k in state_dict.keys())
102
-
103
- if is_lora:
104
- lora_grouped_dict = defaultdict(dict)
105
- for key, value in state_dict.items():
106
- if "alpha" in key:
107
- attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(
108
- key.split(".")[-2:]
109
- )
110
- else:
111
- attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(
112
- key.split(".")[-3:]
113
- )
114
- lora_grouped_dict[attn_processor_key][sub_key] = value
115
-
116
- for key, value_dict in lora_grouped_dict.items():
117
- rank = value_dict["to_k_lora.down.weight"].shape[0]
118
- cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
119
- hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
120
-
121
- attn_processors[key] = LoRACrossAttnProcessor(
122
- hidden_size=hidden_size,
123
- cross_attention_dim=cross_attention_dim,
124
- rank=rank,
125
- scale=scale,
126
- )
127
- attn_processors[key].load_state_dict(value_dict, strict=False)
128
-
129
- else:
130
- raise ValueError(
131
- f"{model_file} does not seem to be in the correct format expected by LoRA training."
132
- )
133
-
134
- # set correct dtype & device
135
- attn_processors = {
136
- k: v.to(device=unet.device, dtype=unet.dtype)
137
- for k, v in attn_processors.items()
138
- }
139
-
140
- # set layers
141
- unet.set_attn_processor(attn_processors)
142
-
143
-
144
  class CrossAttnProcessor(nn.Module):
145
  def __call__(
146
  self,
@@ -148,7 +75,6 @@ class CrossAttnProcessor(nn.Module):
148
  hidden_states,
149
  encoder_hidden_states=None,
150
  attention_mask=None,
151
- qkvo_bias=None,
152
  ):
153
  batch_size, sequence_length, _ = hidden_states.shape
154
  attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
@@ -166,11 +92,6 @@ class CrossAttnProcessor(nn.Module):
166
  key = attn.to_k(encoder_states)
167
  value = attn.to_v(encoder_states)
168
 
169
- if qkvo_bias is not None:
170
- query += qkvo_bias["q"](hidden_states)
171
- key += qkvo_bias["k"](encoder_states)
172
- value += qkvo_bias["v"](encoder_states)
173
-
174
  query = attn.head_to_batch_dim(query)
175
  key = attn.head_to_batch_dim(key)
176
  value = attn.head_to_batch_dim(value)
@@ -219,76 +140,11 @@ class CrossAttnProcessor(nn.Module):
219
  # linear proj
220
  hidden_states = attn.to_out[0](hidden_states)
221
 
222
- if qkvo_bias is not None:
223
- hidden_states += qkvo_bias["o"](hidden_states)
224
-
225
  # dropout
226
  hidden_states = attn.to_out[1](hidden_states)
227
 
228
  return hidden_states
229
 
230
-
231
- class LoRACrossAttnProcessor(CrossAttnProcessor):
232
- def __init__(self, hidden_size, cross_attention_dim=None, rank=4, scale=1.0):
233
- super().__init__()
234
-
235
- self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
236
- self.to_k_lora = LoRALinearLayer(
237
- cross_attention_dim or hidden_size, hidden_size, rank
238
- )
239
- self.to_v_lora = LoRALinearLayer(
240
- cross_attention_dim or hidden_size, hidden_size, rank
241
- )
242
- self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
243
- self.scale = scale
244
-
245
- def __call__(
246
- self,
247
- attn,
248
- hidden_states,
249
- encoder_hidden_states=None,
250
- attention_mask=None,
251
- ):
252
- scale = self.scale
253
- qkvo_bias = {
254
- "q": lambda inputs: scale * self.to_q_lora(inputs),
255
- "k": lambda inputs: scale * self.to_k_lora(inputs),
256
- "v": lambda inputs: scale * self.to_v_lora(inputs),
257
- "o": lambda inputs: scale * self.to_out_lora(inputs),
258
- }
259
- return super().__call__(
260
- attn, hidden_states, encoder_hidden_states, attention_mask, qkvo_bias
261
- )
262
-
263
-
264
- class LoRALinearLayer(nn.Module):
265
- def __init__(self, in_features, out_features, rank=4):
266
- super().__init__()
267
-
268
- if rank > min(in_features, out_features):
269
- raise ValueError(
270
- f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}"
271
- )
272
-
273
- self.down = nn.Linear(in_features, rank, bias=False)
274
- self.up = nn.Linear(rank, out_features, bias=False)
275
- self.scale = 1.0
276
- self.alpha = rank
277
-
278
- nn.init.normal_(self.down.weight, std=1 / rank)
279
- nn.init.zeros_(self.up.weight)
280
-
281
- def forward(self, hidden_states):
282
- orig_dtype = hidden_states.dtype
283
- dtype = self.down.weight.dtype
284
- rank = self.down.out_features
285
-
286
- down_hidden_states = self.down(hidden_states.to(dtype))
287
- up_hidden_states = self.up(down_hidden_states) * (self.alpha / rank)
288
-
289
- return up_hidden_states.to(orig_dtype)
290
-
291
-
292
  class ModelWrapper:
293
  def __init__(self, model, alphas_cumprod):
294
  self.model = model
68
  return attention_scores
69
 
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  class CrossAttnProcessor(nn.Module):
72
  def __call__(
73
  self,
75
  hidden_states,
76
  encoder_hidden_states=None,
77
  attention_mask=None,
 
78
  ):
79
  batch_size, sequence_length, _ = hidden_states.shape
80
  attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
92
  key = attn.to_k(encoder_states)
93
  value = attn.to_v(encoder_states)
94
 
 
 
 
 
 
95
  query = attn.head_to_batch_dim(query)
96
  key = attn.head_to_batch_dim(key)
97
  value = attn.head_to_batch_dim(value)
140
  # linear proj
141
  hidden_states = attn.to_out[0](hidden_states)
142
 
 
 
 
143
  # dropout
144
  hidden_states = attn.to_out[1](hidden_states)
145
 
146
  return hidden_states
147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  class ModelWrapper:
149
  def __init__(self, model, alphas_cumprod):
150
  self.model = model
requirements.txt DELETED
@@ -1,8 +0,0 @@
1
- torch
2
- einops
3
- diffusers
4
- transformers
5
- k_diffusion
6
- safetensors
7
- gradio
8
- torch