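"""OFT (Orthogonal Finetuning) weight adapter.

Provides a trainable module (OFTDiff) and a loader/merger (OFTAdapter) that
apply a block-diagonal orthogonal transform, parameterized as the Cayley
transform of learned skew-symmetric blocks, to model weights.
"""
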
								import logging
from typing import Optional
import torch
import comfy.model_management
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization


class OFTDiff(WeightAdapterTrainBase):
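    """Trainable OFT module.

    Learns per-block skew-symmetric matrices, maps them to orthogonal blocks
    via the Cayley transform, and applies the resulting block-diagonal
    rotation to the weight passed to __call__.
    """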
    def __init__(self, weights):
        super().__init__()
        # Unpack the weights tuple produced by OFTAdapter: (blocks, rescale, alpha, dora_scale)
        blocks, rescale, alpha, _ = weights
        # Create trainable parameters
        self.oft_blocks = torch.nn.Parameter(blocks)
        if rescale is not None:
            self.rescale = torch.nn.Parameter(rescale)
            self.rescaled = True
        else:
            self.rescaled = False
        self.block_num, self.block_size, _ = blocks.shape
        self.constraint = float(alpha)
        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)

    def __call__(self, w):
        org_dtype = w.dtype
        I = torch.eye(self.block_size, device=self.oft_blocks.device)
        ## generate r
        # build a skew-symmetric Q (Q = -Q^T) from the trainable blocks
        q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
        normed_q = q
        if self.constraint:
            q_norm = torch.norm(q) + 1e-8
            if q_norm > self.constraint:
                normed_q = q * self.constraint / q_norm
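        # Cayley transform: R = (I + Q)(I - Q)^-1 is orthogonal for skew-symmetric Q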
        # use float() to prevent unsupported type
        r = (I + normed_q) @ (I - normed_q).float().inverse()
        ## Apply chunked matmul on weight
        _, *shape = w.shape
        org_weight = w.to(dtype=r.dtype)
        org_weight = org_weight.unflatten(0, (self.block_num, self.block_size))
        # oft_blocks is zero-initialized, so R starts as the identity and step 0 reproduces the original weights
        weight = torch.einsum(
            "k n m, k n ... -> k m ...",
            r,
            org_weight,
        ).flatten(0, 1)
        if self.rescaled:
            weight = self.rescale * weight
        return weight.to(org_dtype)

    def passive_memory_usage(self):
        """Calculates memory usage of the trainable parameters."""
        return sum(param.numel() * param.element_size() for param in self.parameters())


class OFTAdapter(WeightAdapterBase):
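    """Loads OFT blocks from a state dict and merges the orthogonal transform into model weights."""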
    name = "oft"
    def __init__(self, loaded_keys, weights):
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def create_train(cls, weight, rank=1, alpha=1.0):
        out_dim = weight.shape[0]
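        # factorization() splits out_dim into block_num blocks of size block_size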
        block_size, block_num = factorization(out_dim, rank)
        block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=weight.dtype)
        return OFTDiff(
            (block, None, alpha, None)
        )

    def to_train(self):
        return OFTDiff(self.weights)

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: Optional[set[str]] = None,
    ) -> Optional["OFTAdapter"]:
        if loaded_keys is None:
            loaded_keys = set()
        blocks_name = "{}.oft_blocks".format(x)
        rescale_name = "{}.rescale".format(x)
        blocks = None
        if blocks_name in lora.keys():
            blocks = lora[blocks_name]
            if blocks.ndim == 3:
                loaded_keys.add(blocks_name)
            else:
                blocks = None
        if blocks is None:
            return None
        rescale = None
        if rescale_name in lora.keys():
            rescale = lora[rescale_name]
            loaded_keys.add(rescale_name)
        weights = (blocks, rescale, alpha, dora_scale)
        return cls(loaded_keys, weights)

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
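        """Merge the OFT transform into `weight` at the given `strength`, using DoRA decomposition when `dora_scale` is provided."""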
        v = self.weights
        blocks = v[0]
        rescale = v[1]
        alpha = v[2]
        if alpha is None:
            alpha = 0
        dora_scale = v[3]
        blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
        if rescale is not None:
            rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)
        block_num, block_size, *_ = blocks.shape
        try:
            # Get r
            I = torch.eye(block_size, device=blocks.device, dtype=blocks.dtype)
            # build a skew-symmetric Q (Q = -Q^T) from the loaded blocks
            q = blocks - blocks.transpose(1, 2)
            normed_q = q
            if alpha > 0:  # in OFT/BOFT, alpha acts as a norm constraint on Q
                q_norm = torch.norm(q) + 1e-8
                if q_norm > alpha:
                    normed_q = q * alpha / q_norm
            # use float() to prevent unsupported type in .inverse()
            r = (I + normed_q) @ (I - normed_q).float().inverse()
            r = r.to(weight)
            _, *shape = weight.shape
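            # apply strength * (R - I) block-wise; the diff is zero when R is the identity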
            lora_diff = torch.einsum(
                "k n m, k n ... -> k m ...",
                (r * strength) - strength * I,
                weight.view(block_num, block_size, *shape),
            ).view(-1, *shape)
            if dora_scale is not None:
                weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
            else:
                weight += function((strength * lora_diff).type(weight.dtype))
        except Exception as e:
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight