Spaces:
Runtime error
Runtime error
File size: 4,303 Bytes
a4d7b31 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
from typing import List
import torch
from safetensors import safe_open
from diffusers import StableDiffusionPipeline
from .lora import (
monkeypatch_or_replace_safeloras,
apply_learned_embed_in_clip,
set_lora_diag,
parse_safeloras_embeds,
)
def lora_join(lora_safetenors: list):
    """Merge several LoRA safetensors handles into one combined tensor set.

    The LoRA weights of all inputs are concatenated along the rank axis, so a
    single injected LoRA layer carries every model at once. Each model
    occupies a contiguous slice of the combined rank, in input order.
    Learned token embeddings are renamed to ``<s{model_idx}-{token_idx}>`` so
    tokens coming from different files cannot collide.

    Args:
        lora_safetenors: Objects exposing ``keys()``, ``metadata()`` and
            ``get_tensor(key)`` (e.g. handles returned by ``safe_open``).

    Returns:
        ``(total_tensor, total_metadata, ranklist, token_size_list)``:
        the merged tensors; the merged metadata, where later inputs override
        earlier ones on duplicate keys and the rank entry of every merged
        LoRA weight is set to the combined rank; the per-input LoRA rank
        (0 for an input with no ``*rank`` metadata); and the number of
        embedding tokens contributed by each input.

    Raises:
        AssertionError: if one input declares more than one distinct rank,
            or if the concatenated rank of a LoRA weight does not equal the
            sum of the per-input ranks (e.g. an input with a non-zero rank
            is missing that weight).
    """
    # ``safe_open(...).metadata()`` returns None for files without a
    # metadata header; treat that as "no metadata".
    metadatas = [dict(safelora.metadata() or {}) for safelora in lora_safetenors]
    key_sets = [set(safelora.keys()) for safelora in lora_safetenors]

    _total_metadata = {}
    total_rank = 0
    ranklist = []
    for _metadata in metadatas:
        rankset = [int(v) for k, v in _metadata.items() if k.endswith("rank")]
        assert len(set(rankset)) <= 1, "Rank should be the same per model"
        # Inputs without any rank entry (e.g. embedding-only files) count as 0.
        rank = rankset[0] if rankset else 0
        total_rank += rank
        ranklist.append(rank)
        _total_metadata.update(_metadata)

    # remove metadata about tokens; embeddings are re-added below under new names
    total_metadata = {k: v for k, v in _total_metadata.items() if v != "<embed>"}

    total_tensor = {}
    for key in set().union(*key_sets):
        if not key.startswith(("text_encoder", "unet")):
            continue
        # Only inputs that actually hold this LoRA weight contribute. Rank-0
        # inputs have nothing to concatenate; an input with a non-zero rank
        # that lacks the weight is caught by the rank assertion below.
        tensorset = [
            safelora.get_tensor(key)
            for safelora, keyset in zip(lora_safetenors, key_sets)
            if key in keyset
        ]
        # The rank axis is dim 0 of "down" weights and dim 1 of the others,
        # which is what the shape check enforces.
        rank_dim = 0 if key.endswith("down") else 1
        _tensor = torch.cat(tensorset, dim=rank_dim)
        assert _tensor.shape[rank_dim] == total_rank, (
            f"LoRA weight {key!r} has combined rank "
            f"{_tensor.shape[rank_dim]}, expected {total_rank}"
        )
        total_tensor[key] = _tensor
        keys_rank = ":".join(key.split(":")[:-1]) + ":rank"
        total_metadata[keys_rank] = str(total_rank)

    token_size_list = []
    for idx, (safelora, _metadata) in enumerate(zip(lora_safetenors, metadatas)):
        tokens = sorted(k for k, v in _metadata.items() if v == "<embed>")
        for jdx, token in enumerate(tokens):
            new_token = f"<s{idx}-{jdx}>"
            total_tensor[new_token] = safelora.get_tensor(token)
            total_metadata[new_token] = "<embed>"
            print(f"Embedding {token} replaced to <s{idx}-{jdx}>")
        token_size_list.append(len(tokens))
    return total_tensor, total_metadata, ranklist, token_size_list
class DummySafeTensorObject:
    """In-memory stand-in for a ``safe_open`` handle.

    Wraps an already-loaded ``{name: tensor}`` mapping plus its metadata and
    exposes the three read accessors (``keys``, ``metadata``, ``get_tensor``)
    that ``lora_join`` calls on real handles.
    """

    def __init__(self, tensor: dict, metadata):
        # Both mappings are stored by reference (no copy is made).
        self._metadata = metadata
        self.tensor = tensor

    def keys(self):
        """Return a live view of the wrapped tensor names."""
        tensors = self.tensor
        return tensors.keys()

    def metadata(self):
        """Return the metadata object given to the constructor."""
        return self._metadata

    def get_tensor(self, key):
        """Return the tensor stored under ``key`` (KeyError if absent)."""
        tensors = self.tensor
        return tensors[key]
class LoRAManager:
    """Load several LoRA files into one pipeline and control them jointly.

    All files are merged with ``lora_join`` and patched into ``pipe`` once.
    ``tune`` then reweights each file's slice of the combined rank, and
    ``prompt`` expands ``<1>``, ``<2>``, ... placeholders into the renamed
    embedding tokens of the corresponding file.
    """

    def __init__(self, lora_paths_list: List[str], pipe: "StableDiffusionPipeline"):
        self.lora_paths_list = lora_paths_list
        self.pipe = pipe
        self._setup()

    def _setup(self):
        """Open every LoRA file, merge them and patch ``self.pipe`` in place."""
        # Handles are opened on CPU; lora_join copies what it needs out of them.
        self._lora_safetenors = [
            safe_open(path, framework="pt", device="cpu")
            for path in self.lora_paths_list
        ]
        (
            total_tensor,
            total_metadata,
            self.ranklist,
            self.token_size_list,
        ) = lora_join(self._lora_safetenors)
        self.total_safelora = DummySafeTensorObject(total_tensor, total_metadata)
        monkeypatch_or_replace_safeloras(self.pipe, self.total_safelora)
        tok_dict = parse_safeloras_embeds(self.total_safelora)
        apply_learned_embed_in_clip(
            tok_dict,
            self.pipe.text_encoder,
            self.pipe.tokenizer,
            token=None,
            idempotent=True,
        )

    def tune(self, scales):
        """Set a per-file strength for the merged LoRA.

        Args:
            scales: One scale per loaded file, in the order of
                ``lora_paths_list``.

        Only ``self.pipe.unet`` is passed to ``set_lora_diag``. Any LoRA
        layers injected into the text encoder are presumably left at their
        current scale (not verified here).
        """
        assert len(scales) == len(
            self.ranklist
        ), "Scale list should be the same length as ranklist"
        # Repeat each scale ``rank`` times so the diagonal lines up with the
        # rank-axis concatenation order produced by ``lora_join``.
        diags = [
            scale
            for scale, rank in zip(scales, self.ranklist)
            for _ in range(rank)
        ]
        set_lora_diag(self.pipe.unet, torch.tensor(diags))

    def prompt(self, prompt):
        """Expand ``<i>`` placeholders (1-based) into file i's embedding tokens.

        ``<1>`` becomes ``<s0-0><s0-1>...`` for as many tokens as the first
        file contributed, and so on. ``None`` is returned unchanged.
        """
        if prompt is not None:
            for idx, tok_size in enumerate(self.token_size_list):
                expanded = "".join(f"<s{idx}-{jdx}>" for jdx in range(tok_size))
                prompt = prompt.replace(f"<{idx + 1}>", expanded)
        # TODO : Rescale LoRA + Text inputs based on prompt scale params
        return prompt
|