Spaces:
Runtime error
Runtime error
| import logging | |
| from typing import Optional | |
| import torch | |
| import comfy.model_management | |
| from .base import WeightAdapterBase, weight_decompose | |
class BOFTAdapter(WeightAdapterBase):
    """Weight adapter for "boft" patches stored under ``<prefix>.oft_blocks``.

    The patch is a stack of block matrices turned into rotations via a Cayley
    transform and applied stage by stage to the rows (dim 0) of a weight, with
    an optional elementwise ``rescale`` afterwards.

    ``self.weights`` is the tuple ``(blocks, rescale, alpha, dora_scale)``
    built by :meth:`load`.
    """

    name = "boft"

    def __init__(self, loaded_keys, weights):
        # loaded_keys: state-dict keys consumed by this adapter (filled in by ``load``).
        # weights: (blocks, rescale, alpha, dora_scale) tuple, see ``load``.
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: Optional[torch.Tensor],
        loaded_keys: Optional[set[str]] = None,
    ) -> Optional["BOFTAdapter"]:
        """Build an adapter for layer prefix ``x`` from a LoRA state dict.

        Args:
            x: Key prefix of the patched layer.
            lora: State dict read for ``"<x>.oft_blocks"`` (required, must be
                4-D) and ``"<x>.rescale"`` (optional).
            alpha: Stored for ``calculate_weight``; when ``> 0`` it bounds the
                norm of the skew-symmetric generator, and it is forwarded to
                ``weight_decompose`` on the DoRA path.
            dora_scale: Stored unchanged; when not ``None``,
                ``calculate_weight`` uses ``weight_decompose``.
            loaded_keys: Set to which consumed keys are added (mutated in
                place). A fresh set is created when omitted.

        Returns:
            A new adapter, or ``None`` when ``"<x>.oft_blocks"`` is absent or
            is not a 4-D tensor (in which case nothing is added to
            ``loaded_keys``).
        """
        if loaded_keys is None:
            loaded_keys = set()
        blocks_name = "{}.oft_blocks".format(x)
        rescale_name = "{}.rescale".format(x)

        blocks = lora.get(blocks_name)
        # Only 4-D oft_blocks are accepted here; anything else is not claimed.
        if blocks is None or blocks.ndim != 4:
            return None
        loaded_keys.add(blocks_name)

        rescale = lora.get(rescale_name)
        if rescale is not None:
            loaded_keys.add(rescale_name)

        return cls(loaded_keys, (blocks, rescale, alpha, dora_scale))

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        """Apply the stored BOFT patch to ``weight`` and return the result.

        Args:
            weight: Tensor to patch. On the non-DoRA path it is updated in
                place via ``+=``.
            key: Used only in the error log message.
            strength: Patch strength. When ``!= 1`` each rotation is
                interpolated towards the identity, and the resulting diff is
                additionally scaled by ``strength``.
            strength_model: Unused; kept for interface compatibility.
            offset: Unused; kept for interface compatibility.
            function: Callable applied to the scaled diff before it is added
                (also forwarded to ``weight_decompose``).
            intermediate_dtype: dtype that ``blocks``, ``rescale`` and the
                diff are cast to.
            original_weight: Unused; kept for interface compatibility.

        Returns:
            The patched weight. If any step inside the ``try`` raises, the
            error is logged and ``weight`` is returned without the patch.
        """
        blocks, rescale, alpha, dora_scale = self.weights

        blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
        if rescale is not None:
            rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)

        # blocks[i] (i < boft_m) holds the block matrices for stage i; the last
        # two dims must be square (boft_b x boft_b) for the transpose below.
        boft_m, _block_num, boft_b, *_ = blocks.shape

        try:
            eye = torch.eye(boft_b, device=blocks.device, dtype=blocks.dtype)
            # q = B - B^T is skew-symmetric (q^T = -q).
            q = blocks - blocks.transpose(-1, -2)
            normed_q = q
            if alpha > 0:  # alpha in boft/bboft is for constraint
                # Shrink q so its overall norm does not exceed alpha.
                q_norm = torch.norm(q) + 1e-8
                if q_norm > alpha:
                    normed_q = q * alpha / q_norm
            # Cayley transform r = (I + q)(I - q)^-1. Both factors are cast to
            # float32: .inverse() does not support every dtype, and ``@`` does
            # not type-promote, so a half-precision left factor times a float32
            # inverse would raise.
            r = (eye + normed_q).float() @ (eye - normed_q).float().inverse()
            r = r.to(weight)
            # Identity in r's dtype/device, so the interpolation below does not
            # promote bi to a different dtype than inp (einsum would reject it).
            eye_r = eye.to(r)

            inp = org = weight
            r_b = boft_b // 2
            g = 2
            for i in range(boft_m):
                bi = r[i]
                k = 2**i * r_b
                if strength != 1:
                    # Blend each rotation towards the identity.
                    # NOTE(review): strength is applied again to lora_diff
                    # below, so it effectively acts twice — confirm intended.
                    bi = bi * strength + (1 - strength) * eye_r
                # Permute rows: split dim 0 into (-1, g, k), swap the g/k axes,
                # then regroup into chunks of boft_b rows.
                inp = (
                    inp.unflatten(0, (-1, g, k))
                    .transpose(1, 2)
                    .flatten(0, 2)
                    .unflatten(0, (-1, boft_b))
                )
                # Multiply chunk b by its own matrix bi[b].
                inp = torch.einsum("b i j, b j ...-> b i ...", bi, inp)
                # Inverse of the row permutation above.
                inp = inp.flatten(0, 1).unflatten(0, (-1, k, g)).transpose(1, 2).flatten(0, 2)

            if rescale is not None:
                # NOTE(review): rescale is broadcast elementwise; its expected
                # shape is not visible here — confirm against the producer.
                inp = inp * rescale

            lora_diff = comfy.model_management.cast_to_device(inp - org, weight.device, intermediate_dtype)
            if dora_scale is not None:
                weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
            else:
                weight += function((strength * lora_diff).type(weight.dtype))
        except Exception as e:
            # Best-effort patching: log and fall through with the weight as-is.
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight