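# Composable LoRA for the Stable Diffusion webui: instead of applying every
# loaded LoRA to the whole batch, each LoRA is applied only to the
# AND-separated subprompt that requests it (Composable-Diffusion style), with
# optional application to the unconditional pass. Builds on the webui's
# built-in Lora extension (the `lora` module and the
# torch.nn.*_forward_before_lora hooks).
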
from typing import List, Dict
import re
import torch
from modules import extra_networks, shared

re_AND = re.compile(r"\bAND\b")


def load_prompt_loras(prompt: str):
    # rebuild the per-subprompt LoRA multiplier table for this prompt
    prompt_loras.clear()
    subprompts = re_AND.split(prompt)
    tmp_prompt_loras = []
    for subprompt in subprompts:
        loras = {}
        _, extra_network_data = extra_networks.parse_prompt(subprompt)
        for params in extra_network_data['lora']:
            name = params.items[0]
            multiplier = float(params.items[1]) if len(params.items) > 1 else 1.0
            loras[name] = multiplier
        tmp_prompt_loras.append(loras)
    # one dict per subprompt, tiled once per batch
    prompt_loras.extend(tmp_prompt_loras * num_batches)
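# A minimal illustration of the resulting table (hypothetical prompt, assuming
# the usual <lora:name:weight> extra-network syntax):
#
#   num_batches = 2
#   load_prompt_loras("a cat <lora:foo:0.6> AND a dog <lora:bar:0.8>")
#   # prompt_loras == [{"foo": 0.6}, {"bar": 0.8}, {"foo": 0.6}, {"bar": 0.8}]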


def reset_counters():
    global text_model_encoder_counter
    global diffusion_model_counter

    # -1 sends the text-encoder counter back to the uc head: lora_forward
    # re-seeds it to num_prompts * num_loras on its next call
    text_model_encoder_counter = -1
    diffusion_model_counter = 0
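# A sketch of what the counters track (inferred from the reset values and the
# orderings noted in lora_forward): the text encoder is called once per
# subprompt cond and once per batch uncond, and the counter ticks once per
# loaded LoRA at the encoder's final layer, so with 2 subprompts, 2 LoRAs and
# 1 batch it walks
#
#   c1 c1 c2 c2 uc uc
#
# before wrapping to 0; dividing by num_loras therefore recovers the subprompt
# index. The diffusion-model counter instead advances by the number of prompt
# rows in the current chunk.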


def lora_forward(compvis_module, input, res):
    global text_model_encoder_counter
    global diffusion_model_counter

    import lora

    if len(lora.loaded_loras) == 0:
        return res

    lora_layer_name: str | None = getattr(compvis_module, 'lora_layer_name', None)
    if lora_layer_name is None:
        return res

    num_loras = len(lora.loaded_loras)
    num_prompts = len(prompt_loras)

    # -1 means "fresh run": seed the counter at the head of the uncond range,
    # which follows the conds
    if text_model_encoder_counter == -1:
        text_model_encoder_counter = num_prompts * num_loras

    # loop variable renamed from `lora` to avoid shadowing the imported module
    for loaded_lora in lora.loaded_loras:
        module = loaded_lora.modules.get(lora_layer_name, None)
        if module is None:
            continue

        # low-rank patch up(down(x)), taken from the output when the option
        # asks for it and the shapes allow, otherwise from the input
        if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
            patch = module.up(module.down(res))
        else:
            patch = module.up(module.down(input))

        # standard LoRA scaling: alpha / rank (rank == up.weight.shape[1])
        alpha = module.alpha / module.up.weight.shape[1] if module.alpha else 1.0

        if enabled:
            if lora_layer_name.startswith("transformer_"):  # "transformer_text_model_encoder_"
                if 0 <= text_model_encoder_counter // num_loras < num_prompts:
                    # cond: apply the multiplier this subprompt requested (0.0 if absent)
                    loras = prompt_loras[text_model_encoder_counter // num_loras]
                    multiplier = loras.get(loaded_lora.name, 0.0)
                    if multiplier != 0.0:
                        res += multiplier * alpha * patch
                else:
                    # uncond: optionally apply the LoRA's global multiplier
                    if opt_uc_text_model_encoder and loaded_lora.multiplier != 0.0:
                        res += loaded_lora.multiplier * alpha * patch

                if lora_layer_name.endswith("_11_mlp_fc2"):  # last lora_layer_name of the text encoder
                    text_model_encoder_counter += 1
                    # counter order: c1 c1 c2 c2 .. .. uc uc (one tick per loaded LoRA)
                    if text_model_encoder_counter == (num_prompts + num_batches) * num_loras:
                        text_model_encoder_counter = 0

            elif lora_layer_name.startswith("diffusion_model_"):
                if res.shape[0] == num_batches * num_prompts + num_batches:
                    # conds and unconds share one tensor (same token length):
                    # all cond rows first, then one uncond row per batch
                    tensor_off = 0
                    uncond_off = num_batches * num_prompts
                    for b in range(num_batches):
                        # cond rows
                        for p, loras in enumerate(prompt_loras):
                            multiplier = loras.get(loaded_lora.name, 0.0)
                            if multiplier != 0.0:
                                res[tensor_off] += multiplier * alpha * patch[tensor_off]
                            tensor_off += 1

                        # uncond row for this batch
                        if opt_uc_diffusion_model and loaded_lora.multiplier != 0.0:
                            res[uncond_off] += loaded_lora.multiplier * alpha * patch[uncond_off]
                            uncond_off += 1
                else:
                    # conds and unconds arrive in separate calls (token lengths
                    # differ); use the counter to tell which chunk this is
                    cur_num_prompts = res.shape[0]
                    base = (diffusion_model_counter // cur_num_prompts) // num_loras * cur_num_prompts
                    if 0 <= base < num_prompts:
                        # cond chunk
                        for off in range(cur_num_prompts):
                            loras = prompt_loras[base + off]
                            multiplier = loras.get(loaded_lora.name, 0.0)
                            if multiplier != 0.0:
                                res[off] += multiplier * alpha * patch[off]
                    else:
                        # uncond chunk
                        if opt_uc_diffusion_model and loaded_lora.multiplier != 0.0:
                            res += loaded_lora.multiplier * alpha * patch

                    if lora_layer_name.endswith("_11_1_proj_out"):  # last lora_layer_name of the diffusion model
                        diffusion_model_counter += cur_num_prompts
                        # counter order: c1 c2 .. uc
                        if diffusion_model_counter >= (num_prompts + num_batches) * num_loras:
                            diffusion_model_counter = 0
            else:
                # layers outside the text encoder / unet naming scheme: plain LoRA
                if loaded_lora.multiplier != 0.0:
                    res += loaded_lora.multiplier * alpha * patch
        else:
            # extension disabled: behave like the stock Lora extension
            if loaded_lora.multiplier != 0.0:
                res += loaded_lora.multiplier * alpha * patch

    return res
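# The patch arithmetic above is plain low-rank adaptation. A self-contained
# illustration with dummy modules (hypothetical sizes, no webui needed):
#
#   import torch
#   x = torch.randn(2, 768)                       # activations entering a Linear
#   down = torch.nn.Linear(768, 4, bias=False)    # rank-4 down-projection
#   up = torch.nn.Linear(4, 768, bias=False)      # back up to 768
#   alpha = 4.0 / up.weight.shape[1]              # module.alpha / rank, as above
#   res = x + 0.8 * alpha * up(down(x))           # res += multiplier * alpha * patch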


def lora_Linear_forward(self, input):
    # run the original forward (stashed by the built-in Lora extension),
    # then add the composable LoRA patches on top
    return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input))


def lora_Conv2d_forward(self, input):
    return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))
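# These wrappers assume the built-in Lora extension already stashed the
# original forwards on torch.nn (Linear_forward_before_lora /
# Conv2d_forward_before_lora). The extension's script presumably swaps them in
# along these lines (a sketch, not the actual wiring):
#
#   torch.nn.Linear.forward = composable_lora.lora_Linear_forward
#   torch.nn.Conv2d.forward = composable_lora.lora_Conv2d_forward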


# module state, configured from the extension's UI/script before generation
enabled = False
opt_uc_text_model_encoder = False
opt_uc_diffusion_model = False
verbose = True

num_batches: int = 0
prompt_loras: List[Dict[str, float]] = []
text_model_encoder_counter: int = -1
diffusion_model_counter: int = 0
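# Expected call sequence from the extension's script (a hedged sketch; names
# like `p` follow the webui's processing object, not confirmed here):
#
#   composable_lora.enabled = True
#   composable_lora.num_batches = p.batch_size
#   composable_lora.load_prompt_loras(p.prompt)
#   composable_lora.reset_counters()   # before each sampling pass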