Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,606 Bytes
d711508 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from peft.import_utils import is_bnb_4bit_available, is_bnb_available
from .layer import AdaLoraLayer
if is_bnb_available():

    class SVDLinear8bitLt(torch.nn.Module, AdaLoraLayer):
        """AdaLoRA (SVD-style low-rank) adapter wrapped around an 8-bit base layer.

        The class is only defined when ``is_bnb_available()`` is true. The
        wrapped layer's weight is frozen and the adapter contribution is added
        to the base layer's output in :meth:`forward`. Merging adapter weights
        into the base layer is not supported.
        """

        def __init__(
            self,
            base_layer: torch.nn.Module,
            adapter_name: str,
            r: int = 0,
            lora_alpha: int = 1,
            lora_dropout: float = 0.0,
            init_lora_weights: bool = True,
            **kwargs,
        ) -> None:
            """Wrap ``base_layer`` and register the adapter ``adapter_name``.

            Args:
                base_layer: Module to wrap; its ``weight`` is frozen here.
                adapter_name: Name under which the adapter is registered and
                    which becomes the active adapter.
                r: Forwarded to ``update_layer`` (presumably the initial rank).
                lora_alpha: Forwarded to ``update_layer``.
                lora_dropout: Forwarded to ``update_layer``.
                init_lora_weights: Forwarded to ``update_layer``.
                **kwargs: Accepted for signature compatibility; not used here.
            """
            super().__init__()
            AdaLoraLayer.__init__(self, base_layer)
            # Freezing the pre-trained weight matrix
            self.get_base_layer().weight.requires_grad = False

            self._active_adapter = adapter_name
            self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)

        def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
            """Return the base layer output plus every active adapter's update.

            Args:
                x: Input tensor passed to the base layer and to the adapters.
                *args: Extra positional arguments forwarded to the base layer.
                **kwargs: Extra keyword arguments forwarded to the base layer.

            Returns:
                The base layer output when adapters are disabled, otherwise the
                base output with each active adapter's contribution added.
            """
            # note: no check for self.merged because merging is not supported (yet)
            result = self.base_layer(x, *args, **kwargs)

            if self.disable_adapters:
                return result

            for active_adapter in self.active_adapters:
                # Active names that have no weights on this layer are skipped.
                if active_adapter not in self.lora_A.keys():
                    continue

                lora_A = self.lora_A[active_adapter]
                lora_B = self.lora_B[active_adapter]
                lora_E = self.lora_E[active_adapter]
                dropout = self.lora_dropout[active_adapter]
                scaling = self.scaling[active_adapter]
                # Small epsilon keeps the division below away from zero.
                ranknum = self.ranknum[active_adapter] + 1e-5

                # Outside autocast the dtypes are aligned by hand: the input is
                # cast to the adapter weights' dtype (rather than a hard-coded
                # float32) so the matmuls below see matching operands, and the
                # adapter output is cast back to the base output's dtype.
                requires_conversion = not torch.is_autocast_enabled()
                if requires_conversion:
                    expected_dtype = result.dtype
                    compute_dtype = lora_A.dtype
                    if x.dtype != compute_dtype:
                        x = x.to(compute_dtype)

                output = dropout(x) @ (lora_A * lora_E).T @ lora_B.T
                if requires_conversion:
                    output = output.to(expected_dtype)
                output = output * scaling / ranknum
                # inplace operation on view is forbidden for MatMul8bitLtBackward, so avoid it
                result = result + output
            return result

        def __repr__(self) -> str:
            """Return the parent representation prefixed with ``adalora.``."""
            rep = super().__repr__()
            return "adalora." + rep
if is_bnb_4bit_available():

    class SVDLinear4bit(torch.nn.Module, AdaLoraLayer):
        """AdaLoRA (SVD-style low-rank) adapter wrapped around a 4-bit base layer.

        Only defined when ``is_bnb_4bit_available()`` is true. The base layer's
        weight is frozen; :meth:`forward` adds the adapter contribution to a
        clone of the base layer's output. Merging is not supported.
        """

        def __init__(
            self,
            base_layer: torch.nn.Module,
            adapter_name: str,
            r: int = 0,
            lora_alpha: int = 1,
            lora_dropout: float = 0.0,
            init_lora_weights: bool = True,
            **kwargs,
        ) -> None:
            """Wrap ``base_layer`` and register the adapter ``adapter_name``.

            ``r``, ``lora_alpha``, ``lora_dropout`` and ``init_lora_weights``
            are handed to ``update_layer`` unchanged; ``**kwargs`` is accepted
            but not used here.
            """
            super().__init__()
            AdaLoraLayer.__init__(self, base_layer)

            # The pre-trained weight must not receive gradients.
            pretrained_weight = self.get_base_layer().weight
            pretrained_weight.requires_grad = False

            self._active_adapter = adapter_name
            self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)

        def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
            """Return the base layer output plus every active adapter's update.

            Extra positional and keyword arguments go to the base layer only.
            When adapters are disabled the base output is returned untouched.
            """
            # Merging is not supported (yet), so self.merged is never consulted.
            base_output = self.base_layer(x, *args, **kwargs)
            if self.disable_adapters:
                return base_output

            # As per Tim Dettmers, for 4bit, we need to defensively clone here.
            # Backprop can fail on a manipulated view in some cases; newer
            # PyTorch versions may have fixed this, but that would need
            # extensive testing to be sure.
            result = base_output.clone()

            for adapter in self.active_adapters:
                # Active names without weights on this layer contribute nothing.
                if active_is_missing := adapter not in self.lora_A.keys():
                    continue

                adapter_a = self.lora_A[adapter]
                adapter_b = self.lora_B[adapter]
                adapter_e = self.lora_E[adapter]
                drop = self.lora_dropout[adapter]
                scale = self.scaling[adapter]
                # Epsilon guards the division against a zero rank number.
                rank_norm = self.ranknum[adapter] + 1e-5

                # Without autocast, dtypes are aligned manually: input goes to
                # the adapter weights' dtype, the update back to result's dtype.
                manual_cast = not torch.is_autocast_enabled()
                if manual_cast:
                    target_dtype = result.dtype
                    weight_dtype = adapter_a.dtype
                    if x.dtype != weight_dtype:
                        x = x.to(weight_dtype)

                projected = drop(x) @ (adapter_a * adapter_e).T
                update = projected @ adapter_b.T
                if manual_cast:
                    update = update.to(target_dtype)

                # In-place accumulation onto the cloned base output.
                result += update * scale / rank_norm

            return result

        def __repr__(self) -> str:
            """Return the parent representation prefixed with ``adalora.``."""
            return "adalora." + super().__repr__()
|