from typing import List

import torch
from safetensors import safe_open

from diffusers import StableDiffusionPipeline

from .lora import (
    monkeypatch_or_replace_safeloras,
    apply_learned_embed_in_clip,
    set_lora_diag,
    parse_safeloras_embeds,
)


def lora_join(lora_safetenors: list):
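    """Join several opened LoRA .safetensors handles into a single LoRA.

    Matching up/down weights are concatenated along the rank dimension, so
    the joined LoRA's rank is the sum of the individual ranks; metadata is
    merged, and learned token embeddings are renamed to <s{model}-{token}>.
    Returns (total_tensor, total_metadata, ranklist, token_size_list).
    """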
    metadatas = [dict(safelora.metadata()) for safelora in lora_safetenors]
    _total_metadata = {}
    total_metadata = {}
    total_tensor = {}
    total_rank = 0
    ranklist = []
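
    # Collect each model's rank (all layers of one model must share a rank)
    # and accumulate the total rank of the joined LoRA.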
    for _metadata in metadatas:
        rankset = []
        for k, v in _metadata.items():
            if k.endswith("rank"):
                rankset.append(int(v))

        assert len(set(rankset)) <= 1, "Rank should be the same per model"
        if len(rankset) == 0:
            rankset = [0]

        total_rank += rankset[0]
        _total_metadata.update(_metadata)
        ranklist.append(rankset[0])

    # remove metadata about tokens
    for k, v in _total_metadata.items():
        if v != "<embed>":
            total_metadata[k] = v

    tensorkeys = set()
    for safelora in lora_safetenors:
        tensorkeys.update(safelora.keys())
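
    # Concatenate matching tensors across models along the rank dimension:
    # "down" weights stack on dim 0, "up" weights stack on dim 1, so either
    # way the rank axis of the result must equal total_rank.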
    for keys in tensorkeys:
        if keys.startswith("text_encoder") or keys.startswith("unet"):
            tensorset = [safelora.get_tensor(keys) for safelora in lora_safetenors]

            is_down = keys.endswith("down")

            if is_down:
                _tensor = torch.cat(tensorset, dim=0)
                assert _tensor.shape[0] == total_rank
            else:
                _tensor = torch.cat(tensorset, dim=1)
                assert _tensor.shape[1] == total_rank

            total_tensor[keys] = _tensor
            keys_rank = ":".join(keys.split(":")[:-1]) + ":rank"
            total_metadata[keys_rank] = str(total_rank)
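
    # Re-export each model's learned token embeddings under unique names
    # (<s{model_idx}-{token_idx}>) so tokens from different LoRAs cannot collide.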
    token_size_list = []
    for idx, safelora in enumerate(lora_safetenors):
        tokens = [k for k, v in safelora.metadata().items() if v == "<embed>"]
        for jdx, token in enumerate(sorted(tokens)):
            total_tensor[f"<s{idx}-{jdx}>"] = safelora.get_tensor(token)
            total_metadata[f"<s{idx}-{jdx}>"] = "<embed>"
            print(f"Embedding {token} replaced by <s{idx}-{jdx}>")

        token_size_list.append(len(tokens))

    return total_tensor, total_metadata, ranklist, token_size_list


class DummySafeTensorObject:
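    """In-memory stand-in for a safetensors file handle.

    Mimics the keys()/metadata()/get_tensor() interface of a handle returned
    by safe_open(), so the joined tensors from lora_join() can be passed to
    code that expects a loaded .safetensors file.
    """
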
    def __init__(self, tensor: dict, metadata):
        self.tensor = tensor
        self._metadata = metadata

    def keys(self):
        return self.tensor.keys()

    def metadata(self):
        return self._metadata

    def get_tensor(self, key):
        return self.tensor[key]


class LoRAManager:
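    """Joins several LoRA .safetensors files into one LoRA patched onto the
    given pipeline, and re-weights each LoRA at inference time via tune().
    """
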
    def __init__(self, lora_paths_list: List[str], pipe: StableDiffusionPipeline):
        self.lora_paths_list = lora_paths_list
        self.pipe = pipe
        self._setup()

    def _setup(self):
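        # Open every LoRA file, join them into one combined LoRA, patch the
        # pipeline's UNet/text encoder with it, and register the learned
        # token embeddings with the tokenizer and text encoder.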
        self._lora_safetenors = [
            safe_open(path, framework="pt", device="cpu")
            for path in self.lora_paths_list
        ]

        (
            total_tensor,
            total_metadata,
            self.ranklist,
            self.token_size_list,
        ) = lora_join(self._lora_safetenors)

        self.total_safelora = DummySafeTensorObject(total_tensor, total_metadata)

        monkeypatch_or_replace_safeloras(self.pipe, self.total_safelora)
        tok_dict = parse_safeloras_embeds(self.total_safelora)
        apply_learned_embed_in_clip(
            tok_dict,
            self.pipe.text_encoder,
            self.pipe.tokenizer,
            token=None,
            idempotent=True,
        )

    def tune(self, scales):
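        # Build the per-rank scale vector: each LoRA's block of `rank` entries
        # gets that LoRA's scale, then it is written to the joint LoRA diagonal.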
        assert len(scales) == len(
            self.ranklist
        ), "Scale list should be the same length as ranklist"

        diags = []
        for scale, rank in zip(scales, self.ranklist):
            diags = diags + [scale] * rank

        set_lora_diag(self.pipe.unet, torch.tensor(diags))

    def prompt(self, prompt):
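        # Expand user-facing placeholders <1>, <2>, ... into the joined
        # embedding tokens <s{idx}-{jdx}> registered for each LoRA.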
        if prompt is not None:
            for idx, tok_size in enumerate(self.token_size_list):
                prompt = prompt.replace(
                    f"<{idx + 1}>",
                    "".join([f"<s{idx}-{jdx}>" for jdx in range(tok_size)]),
                )

        # TODO : Rescale LoRA + Text inputs based on prompt scale params

        return prompt
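

# Minimal usage sketch (model id, file paths, and scales below are illustrative,
# not part of this module):
#
#   pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
#   manager = LoRAManager(["a.safetensors", "b.safetensors"], pipe)
#   manager.tune([0.8, 0.5])  # one scale per LoRA
#   text = manager.prompt("a photo of <1> standing next to <2>")
#   image = pipe(text).images[0]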