# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Optional, Set, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from peft.tuners.lycoris_utils import LycorisLayer


class LoKrLayer(nn.Module, LycorisLayer):
    # All names of layers that may contain adapter weights
    adapter_layer_names = (
        "lokr_w1",
        "lokr_w1_a",
        "lokr_w1_b",
        "lokr_w2",
        "lokr_w2_a",
        "lokr_w2_b",
        "lokr_t2",
    )
    # other_param_names is defined on parent class

    def __init__(self, base_layer: nn.Module) -> None:
        super().__init__()
        LycorisLayer.__init__(self, base_layer)

        # LoKr info
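        # Each ParameterDict below is keyed by adapter name. For each Kronecker factor, either the
        # full matrix (lokr_w1 / lokr_w2) or its low-rank pair (lokr_w*_a, lokr_w*_b) is populated,
        # never both; lokr_t2 holds the optional core tensor used for effective Conv2d decomposition.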
        self.lokr_w1 = nn.ParameterDict({})
        self.lokr_w1_a = nn.ParameterDict({})
        self.lokr_w1_b = nn.ParameterDict({})
        self.lokr_w2 = nn.ParameterDict({})
        self.lokr_w2_a = nn.ParameterDict({})
        self.lokr_w2_b = nn.ParameterDict({})
        self.lokr_t2 = nn.ParameterDict({})

    @property
    def _available_adapters(self) -> Set[str]:
        return {
            *self.lokr_w1,
            *self.lokr_w1_a,
            *self.lokr_w1_b,
            *self.lokr_w2,
            *self.lokr_w2_a,
            *self.lokr_w2_b,
            *self.lokr_t2,
        }

    def create_adapter_parameters(
        self,
        adapter_name: str,
        r: int,
        shape,
        use_w1: bool,
        use_w2: bool,
        use_effective_conv2d: bool,
    ):
        if use_w1:
            self.lokr_w1[adapter_name] = nn.Parameter(torch.empty(shape[0][0], shape[1][0]))
        else:
            self.lokr_w1_a[adapter_name] = nn.Parameter(torch.empty(shape[0][0], r))
            self.lokr_w1_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][0]))

        if len(shape) == 4:
            # Conv2d
            if use_w2:
                self.lokr_w2[adapter_name] = nn.Parameter(torch.empty(shape[0][1], shape[1][1], *shape[2:]))
            elif use_effective_conv2d:
                self.lokr_t2[adapter_name] = nn.Parameter(torch.empty(r, r, shape[2], shape[3]))
                self.lokr_w2_a[adapter_name] = nn.Parameter(torch.empty(r, shape[0][1]))  # b, 1-mode
                self.lokr_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][1]))  # d, 2-mode
            else:
                self.lokr_w2_a[adapter_name] = nn.Parameter(torch.empty(shape[0][1], r))
                self.lokr_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][1] * shape[2] * shape[3]))
        else:
            # Linear
            if use_w2:
                self.lokr_w2[adapter_name] = nn.Parameter(torch.empty(shape[0][1], shape[1][1]))
            else:
                self.lokr_w2_a[adapter_name] = nn.Parameter(torch.empty(shape[0][1], r))
                self.lokr_w2_b[adapter_name] = nn.Parameter(torch.empty(r, shape[1][1]))

    def reset_adapter_parameters(self, adapter_name: str):
        if adapter_name in self.lokr_w1:
            nn.init.zeros_(self.lokr_w1[adapter_name])
        else:
            nn.init.zeros_(self.lokr_w1_a[adapter_name])
            nn.init.kaiming_uniform_(self.lokr_w1_b[adapter_name], a=math.sqrt(5))

        if adapter_name in self.lokr_w2:
            nn.init.kaiming_uniform_(self.lokr_w2[adapter_name], a=math.sqrt(5))
        else:
            nn.init.kaiming_uniform_(self.lokr_w2_a[adapter_name], a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.lokr_w2_b[adapter_name], a=math.sqrt(5))

        if adapter_name in self.lokr_t2:
            nn.init.kaiming_uniform_(self.lokr_t2[adapter_name], a=math.sqrt(5))

    def reset_adapter_parameters_random(self, adapter_name: str):
        if adapter_name in self.lokr_w1:
            nn.init.kaiming_uniform_(self.lokr_w1[adapter_name], a=math.sqrt(5))
        else:
            nn.init.kaiming_uniform_(self.lokr_w1_a[adapter_name], a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.lokr_w1_b[adapter_name], a=math.sqrt(5))

        if adapter_name in self.lokr_w2:
            nn.init.kaiming_uniform_(self.lokr_w2[adapter_name], a=math.sqrt(5))
        else:
            nn.init.kaiming_uniform_(self.lokr_w2_a[adapter_name], a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.lokr_w2_b[adapter_name], a=math.sqrt(5))

        if adapter_name in self.lokr_t2:
            nn.init.kaiming_uniform_(self.lokr_t2[adapter_name], a=math.sqrt(5))

    def update_layer(
        self,
        adapter_name: str,
        r: int,
        alpha: float,
        rank_dropout: float,
        module_dropout: float,
        init_weights: bool,
        use_effective_conv2d: bool,
        decompose_both: bool,
        decompose_factor: int,
        **kwargs,
    ) -> None:
        """Internal function to create lokr adapter

        Args:
            adapter_name (`str`): Name for the adapter to add.
            r (`int`): Rank for the added adapter.
            alpha (`float`): Alpha for the added adapter.
            rank_dropout (`float`): The dropout probability for rank dimension during training.
            module_dropout (`float`): The dropout probability for disabling adapter during training.
            init_weights (`bool`): Whether to initialize adapter weights.
            use_effective_conv2d (`bool`): Use parameter effective decomposition for Conv2d with ksize > 1.
            decompose_both (`bool`): Perform rank decomposition of left kronecker product matrix.
            decompose_factor (`int`): Kronecker product decomposition factor.
        """
        if r <= 0:
            raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

        self.r[adapter_name] = r
        self.alpha[adapter_name] = alpha
        self.scaling[adapter_name] = alpha / r
        self.rank_dropout[adapter_name] = rank_dropout
        self.module_dropout[adapter_name] = module_dropout
        base_layer = self.get_base_layer()

        # Determine shape of LoKr weights
        if isinstance(base_layer, nn.Linear):
            in_dim, out_dim = base_layer.in_features, base_layer.out_features

            in_m, in_n = factorization(in_dim, decompose_factor)
            out_l, out_k = factorization(out_dim, decompose_factor)
            shape = ((out_l, out_k), (in_m, in_n))  # ((a, b), (c, d)), out_dim = a*c, in_dim = b*d
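            # Worked example: a 768 -> 768 nn.Linear with decompose_factor=-1 gives
            # factorization(768) == (24, 32), so w1 (or w1_a @ w1_b) is 24x24, w2 (or w2_a @ w2_b)
            # is 32x32, and kron(w1, w2) rebuilds the full 768x768 delta weight.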
            use_w1 = not (decompose_both and r < max(shape[0][0], shape[1][0]) / 2)
            use_w2 = not (r < max(shape[0][1], shape[1][1]) / 2)
            use_effective_conv2d = False
        elif isinstance(base_layer, nn.Conv2d):
            in_dim, out_dim = base_layer.in_channels, base_layer.out_channels
            k_size = base_layer.kernel_size

            in_m, in_n = factorization(in_dim, decompose_factor)
            out_l, out_k = factorization(out_dim, decompose_factor)
            shape = ((out_l, out_k), (in_m, in_n), *k_size)  # ((a, b), (c, d), *k_size)

            use_w1 = not (decompose_both and r < max(shape[0][0], shape[1][0]) / 2)
            use_w2 = not (r < max(shape[0][1], shape[1][1]) / 2)
            use_effective_conv2d = use_effective_conv2d and base_layer.kernel_size != (1, 1)
        else:
            raise TypeError(f"LoKr is not implemented for base layers of type {type(base_layer).__name__}")

        # Create weights with provided shape
        self.create_adapter_parameters(adapter_name, r, shape, use_w1, use_w2, use_effective_conv2d)

        # Initialize weights
        if init_weights:
            self.reset_adapter_parameters(adapter_name)
        else:
            self.reset_adapter_parameters_random(adapter_name)

        # Move new weights to device
        weight = getattr(self.get_base_layer(), "weight", None)
        if weight is not None:
            # the layer is already completely initialized, this is an update
            if weight.dtype.is_floating_point or weight.dtype.is_complex:
                self.to(weight.device, dtype=weight.dtype)
            else:
                self.to(weight.device)
        self.set_adapter(self.active_adapters)

    def get_delta_weight(self, adapter_name: str) -> torch.Tensor:
        # https://github.com/KohakuBlueleaf/LyCORIS/blob/e4259b870d3354a9615a96be61cb5d07455c58ea/lycoris/modules/lokr.py#L224
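        # The delta weight is rebuilt as a Kronecker product delta_W = kron(w1, w2): each factor is
        # either stored directly or reconstructed from its low-rank pair (w*_a @ w*_b), or, for the
        # effective Conv2d case, from the core tensor lokr_t2 together with lokr_w2_a and lokr_w2_b.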
        if adapter_name in self.lokr_w1:
            w1 = self.lokr_w1[adapter_name]
        else:
            w1 = self.lokr_w1_a[adapter_name] @ self.lokr_w1_b[adapter_name]

        if adapter_name in self.lokr_w2:
            w2 = self.lokr_w2[adapter_name]
        elif adapter_name in self.lokr_t2:
            w2 = make_weight_cp(self.lokr_t2[adapter_name], self.lokr_w2_a[adapter_name], self.lokr_w2_b[adapter_name])
        else:
            w2 = self.lokr_w2_a[adapter_name] @ self.lokr_w2_b[adapter_name]

        # Make weights with Kronecker product
        weight = make_kron(w1, w2)
        weight = weight.reshape(self.get_base_layer().weight.shape)

        # Perform rank dropout during training - drop rows of addition weights
        rank_dropout = self.rank_dropout[adapter_name]
        if self.training and rank_dropout:
            drop = (torch.rand(weight.size(0)) > rank_dropout).float()
            drop = drop.view(-1, *[1] * len(weight.shape[1:])).to(weight.device)
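            # Rescale the surviving rows so the expected magnitude of the delta weight is unchanged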
            drop /= drop.mean()
            weight *= drop

        return weight

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        previous_dtype = x.dtype

        if self.disable_adapters:
            if self.merged:
                self.unmerge()
            result = self.base_layer(x, *args, **kwargs)
        elif self.merged:
            result = self.base_layer(x, *args, **kwargs)
        else:
            result = self.base_layer(x, *args, **kwargs)

            # Execute all the adapters
            for active_adapter in self.active_adapters:
                if active_adapter not in self._available_adapters:
                    continue

                module_dropout = self.module_dropout[active_adapter]

                # Add the adapter delta unless module dropout randomly disables this adapter for the step
                if (not self.training) or (self.training and torch.rand(1) > module_dropout):
                    result = result + self._get_delta_activations(active_adapter, x, *args, **kwargs)

        result = result.to(previous_dtype)
        return result


class Linear(LoKrLayer):
    """LoKr implemented in Linear layer"""

    def __init__(
        self,
        base_layer: nn.Module,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
        adapter_name: str = "default",
        r: int = 0,
        alpha: float = 0.0,
        rank_dropout: float = 0.0,
        module_dropout: float = 0.0,
        init_weights: bool = True,
        **kwargs,
    ):
        super().__init__(base_layer)

        # Create adapter and set it active
        self._active_adapter = adapter_name
        self.update_layer(adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, **kwargs)

    def _get_delta_activations(
        self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any
    ) -> torch.Tensor:
        delta_weight = self.get_delta_weight(adapter_name)
        # don't add bias here, because the bias is already included in the output of the base_layer
        return F.linear(input, delta_weight)

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "lokr." + rep


class Conv2d(LoKrLayer):
    """LoKr implemented in Conv2d layer"""

    def __init__(
        self,
        base_layer: nn.Module,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
        adapter_name: str = "default",
        r: int = 0,
        alpha: float = 0.0,
        rank_dropout: float = 0.0,
        module_dropout: float = 0.0,
        use_effective_conv2d: bool = False,
        init_weights: bool = True,
        **kwargs,
    ):
        super().__init__(base_layer)

        # Create adapter and set it active
        self._active_adapter = adapter_name
        self.update_layer(
            adapter_name, r, alpha, rank_dropout, module_dropout, init_weights, use_effective_conv2d, **kwargs
        )

    def _get_delta_activations(
        self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any
    ) -> torch.Tensor:
        delta_weight = self.get_delta_weight(adapter_name)
        # don't add bias here, because the bias is already included in the output of the base_layer
        base_layer = self.get_base_layer()
        return F.conv2d(
            input,
            delta_weight,
            stride=base_layer.stride,
            padding=base_layer.padding,
            dilation=base_layer.dilation,
            groups=base_layer.groups,
        )

    def __repr__(self) -> str:
        rep = super().__repr__()
        return "lokr." + rep


# Below code is a direct copy from https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/lokr.py#L11


def factorization(dimension: int, factor: int = -1) -> Tuple[int, int]:
    """Factorizes the provided number into the product of two numbers

    Args:
        dimension (`int`): The number that needs to be factorized.
        factor (`int`, optional):
            Factorization divider. The algorithm will try to output two numbers, one of which will be as close to
            the factor as possible. If -1 is provided, the algorithm will search for dividers near the square root
            of the dimension. Defaults to -1.

    Returns:
        Tuple[`int`, `int`]: A tuple of two numbers whose product is equal to the provided number. The first number
        is always less than or equal to the second.

    Example:
        ```py
        >>> factorization(256, factor=-1)
        (16, 16)
        >>> factorization(128, factor=-1)
        (8, 16)
        >>> factorization(127, factor=-1)
        (1, 127)
        >>> factorization(128, factor=4)
        (4, 32)
        ```
    """
    if factor > 0 and (dimension % factor) == 0:
        m = factor
        n = dimension // factor
        if m > n:
            n, m = m, n
        return m, n
    if factor == -1:
        factor = dimension
    m, n = 1, dimension
    length = m + n
    while m < n:
        new_m = m + 1
        while dimension % new_m != 0:
            new_m += 1
        new_n = dimension // new_m
        if new_m + new_n > length or new_m > factor:
            break
        else:
            m, n = new_m, new_n
    if m > n:
        n, m = m, n
    return m, n


def make_weight_cp(t, wa, wb):
    # Rebuild the second Kronecker factor for the effective Conv2d case: the core tensor
    # t [r, r, k1, k2] is contracted with wa and wb over their shared rank dimensions.
    rebuild2 = torch.einsum("i j k l, i p, j r -> p r k l", t, wa, wb)  # [c, d, k1, k2]
    return rebuild2


def make_kron(w1, w2, scale=1.0):
    # Compute the (scaled) Kronecker product of the two factors. For Conv2d weights, w1 gets two
    # singleton spatial dims so the product keeps the kernel size of w2 in the spatial dimensions.
    if len(w2.shape) == 4:
        w1 = w1.unsqueeze(2).unsqueeze(2)
    w2 = w2.contiguous()
    rebuild = torch.kron(w1, w2)

    return rebuild * scale
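

# The block below is a minimal, illustrative sanity check, not part of the library API: it assumes
# this module can be run directly and simply wraps a plain nn.Linear in the LoKr `Linear` layer
# defined above. With `init_weights=True` the delta weight starts at zero, so the adapted output
# should match the base layer exactly.
if __name__ == "__main__":
    base = nn.Linear(768, 768)
    lokr = Linear(
        base,
        adapter_name="default",
        r=8,
        alpha=16.0,
        rank_dropout=0.0,
        module_dropout=0.0,
        init_weights=True,
        use_effective_conv2d=False,
        decompose_both=False,
        decompose_factor=-1,
    )
    x = torch.randn(2, 768)
    # Zero-initialized w1 makes kron(w1, w2) zero, so both outputs should be identical at this point.
    print(torch.allclose(lokr(x), base(x)))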