|
|
|
|
|
|
|
|
|
|
|
|
|
r""" |
|
Low-Rank Adaptation (LoRA) scheme for LLMs.
|
|
|
 ┌───────────────────┐
 ┆         h         ┆
 └───────────────────┘
           ▲
           |
           +
        /     \
┌─────────────────┐    ╭───────────────╮     Matrix initialization:
┆                 ┆     \      B      /      B = 0
┆   pretrained    ┆      \    r*d    /       A = N(0, sigma^2)
┆    weights      ┆       ╰─────────╯
┆                 ┆       |    r    |        r - rank
┆   W e R^(d*d)   ┆       | ◀─────▶ |
┆                 ┆       ╭─────────╮
└─────────────────┘      /     A     \
          ▲             /     d*r     \
           \           ╰───────────────╯
            \                ▲
             \              /
              \            /
┌───────────────────┐
┆         x         ┆
└───────────────────┘
|
|
|
With LoRA (Low-Rank Adaptation: https://arxiv.org/abs/2106.09685), instead of learning a full d*d weight update,
we freeze the pretrained weights and learn two matrices of size d*r and r*d that store the weight updates for the
pretrained weights. The number of trainable parameters is reduced drastically (depending on the rank, of course),
yet the product of the d*r and r*d matrices is again a d*d matrix, which can be added to the frozen pretrained
weights, thereby fine-tuning the model.
|
|
|
The goal of this approach is to move the weight updates into a separate matrix that is decomposed into
two matrices of a lower rank.
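
For a rough sense of the savings (illustrative numbers, not taken from the paper): with d = 4096 and r = 8, a full
weight update would train d*d = 16,777,216 parameters, whereas LoRA trains only d*r + r*d = 65,536.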
|
""" |
|
|
|
import math |
|
from dataclasses import dataclass |
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.nn import functional as F |
|
from typing_extensions import Self |
|
|
|
import lit_gpt |
|
from lit_gpt.config import Config as BaseConfig |
|
from lit_gpt.model import GPT as BaseModel |
|
from lit_gpt.model import Block as BaseBlock |
|
from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention |
|
from lit_gpt.model import KVCache |
|
from lit_gpt.utils import map_old_state_dict_weights |
|
|
|
|
|
class LoRALayer(nn.Module): |
|
def __init__(self, r: int, lora_alpha: int, lora_dropout: float): |
|
"""Store LoRA specific attributes in a class. |
|
|
|
Args: |
|
r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of |
|
the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2) |
|
lora_alpha: alpha is needed for scaling updates as alpha/r |
|
"This scaling helps to reduce the need to retune hyperparameters when we vary r" |
|
https://arxiv.org/pdf/2106.09685.pdf (section 4.1) |
|
lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A) |
|
""" |
|
super().__init__() |
|
assert r >= 0 |
|
self.r = r |
|
self.lora_alpha = lora_alpha |
|
|
|
if lora_dropout > 0.0: |
|
self.lora_dropout = nn.Dropout(p=lora_dropout) |
|
else: |
|
self.lora_dropout = lambda x: x |
|
|
|
self.merged = False |
|
|
|
|
|
class LoRALinear(LoRALayer): |
|
|
|
def __init__( |
|
self, |
|
|
|
in_features: int, |
|
out_features: int, |
|
|
|
r: int = 0, |
|
lora_alpha: int = 1, |
|
lora_dropout: float = 0.0, |
|
**kwargs, |
|
): |
|
"""LoRA wrapper around linear class. |
|
|
|
This class has three weight matrices: |
|
1. Pretrained weights are stored as `self.linear.weight` |
|
2. LoRA A matrix as `self.lora_A` |
|
3. LoRA B matrix as `self.lora_B` |
|
Only LoRA's A and B matrices are updated, pretrained weights stay frozen. |
|
|
|
Args: |
|
in_features: number of input features of the pretrained weights |
|
out_features: number of output features of the pretrained weights |
|
r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of |
|
the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2) |
|
lora_alpha: alpha is needed for scaling updates as alpha/r |
|
"This scaling helps to reduce the need to retune hyperparameters when we vary r" |
|
https://arxiv.org/pdf/2106.09685.pdf (section 4.1) |
|
lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A) |
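
        Example (a minimal, illustrative sketch):
            layer = LoRALinear(128, 256, r=8, lora_alpha=16)
            out = layer(torch.randn(2, 10, 128))  # (2, 10, 256); adds (x @ A^T @ B^T) * alpha/r to the frozen output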
|
""" |
|
super().__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout) |
|
self.linear = torch.nn.Linear(in_features, out_features, **kwargs) |
|
|
|
|
|
if r > 0: |
|
self.lora_A = nn.Parameter(torch.zeros((r, in_features))) |
|
self.lora_B = nn.Parameter(torch.zeros((out_features, r))) |
|
self.scaling = self.lora_alpha / self.r |
|
self.reset_parameters() |
|
|
|
def reset_parameters(self) -> None: |
|
"""Reset all the weights, even including pretrained ones.""" |
|
if hasattr(self, "lora_A"): |
|
|
|
|
|
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) |
|
nn.init.zeros_(self.lora_B) |
|
|
|
def merge(self) -> None: |
|
"""Merges the LoRA weights into the full-rank weights (W = W + delta_W).""" |
|
if self.r > 0 and not self.merged: |
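            # Fold the low-rank update into the frozen weight matrix: W <- W + (B @ A) * (alpha / r)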
|
|
|
self.linear.weight.data += (self.lora_B @ self.lora_A) * self.scaling |
|
self.merged = True |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
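        # If LoRA is disabled (r == 0) or already merged, this is a plain nn.Linear forward pass; otherwise the
        # LoRA branch (dropout(x) @ A^T @ B^T) * (alpha / r) is added to the frozen output.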
|
|
|
|
|
pretrained = self.linear(x) |
|
if self.r == 0 or self.merged: |
|
return pretrained |
|
lora = (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling |
|
return pretrained + lora |
|
|
|
|
|
class LoRAQKVLinear(LoRALinear): |
|
|
|
def __init__( |
|
self, |
|
|
|
in_features: int, |
|
out_features: int, |
|
|
|
n_head: int, |
|
n_query_groups: int, |
|
r: int = 0, |
|
lora_alpha: int = 1, |
|
lora_dropout: float = 0.0, |
|
enable_lora: Union[bool, Tuple[bool, bool, bool]] = False, |
|
**kwargs, |
|
): |
|
"""LoRA wrapper around linear class that is used for calculation of q, k and v matrices. |
|
|
|
This class has three weight matrices: |
|
1. Pretrained weights are stored as `self.linear.weight` |
|
2. LoRA A matrix as `self.lora_A` |
|
3. LoRA B matrix as `self.lora_B` |
|
Only LoRA's A and B matrices are updated, pretrained weights stay frozen. |
|
|
|
Args: |
|
in_features: number of input features of the pretrained weights |
|
out_features: number of output features of the pretrained weights |
|
n_head: number of attention heads |
|
n_query_groups: number of query groups (see diagram in `lit_gpt/config.py`) |
|
r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of |
|
the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2) |
|
lora_alpha: alpha is needed for scaling updates as alpha/r |
|
"This scaling helps to reduce the need to retune hyperparameters when we vary r" |
|
https://arxiv.org/pdf/2106.09685.pdf (section 4.1) |
|
lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A) |
|
            enable_lora: this class is used for the attention mechanism, where q, k and v are computed with a single
                weight matrix. If we don't want to apply LoRA to all three of them we can enable it selectively; for
                example, to apply LoRA only to `query` and `value` but keep `key` without weight updates we should
                pass `(True, False, True)`
|
""" |
|
super(LoRALinear, self).__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout) |
|
self.linear = torch.nn.Linear(in_features, out_features, **kwargs) |
|
self.n_head = n_head |
|
self.n_query_groups = n_query_groups |
|
if isinstance(enable_lora, bool): |
|
enable_lora = [enable_lora] * 3 |
|
assert len(enable_lora) == 3 |
|
self.enable_lora = enable_lora |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if r > 0 and any(enable_lora): |
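            # A single lora_A matrix serves all enabled projections: one block of `r` rows per enabled member
            # of (q, k, v), hence its first dimension is r * sum(enable_lora).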
|
self.lora_A = nn.Parameter(torch.zeros((r * sum(enable_lora), in_features))) |
|
enable_q, enable_k, enable_v = enable_lora |
|
self.kv_embd_size = self.linear.in_features // (n_head // n_query_groups) |
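            # Per-projection output widths: the query part is n_embd wide, while the key and value parts are
            # kv_embd_size wide (narrower than n_embd when grouped-/multi-query attention is used).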
|
|
|
qkv_shapes = ( |
|
self.linear.in_features * enable_q, |
|
self.kv_embd_size * enable_k, |
|
self.kv_embd_size * enable_v, |
|
) |
|
self.qkv_shapes = [s for s in qkv_shapes if s] |
|
self.lora_B = nn.Parameter(torch.zeros(sum(self.qkv_shapes), r)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.scaling = self.lora_alpha / self.r |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.lora_ind = [] |
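            # Column indices of the combined QKV output that receive LoRA updates; `zero_pad` uses them to scatter
            # the concatenated q/k/v updates into a tensor of the full `out_features` width.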
|
if enable_q: |
|
self.lora_ind.extend(range(0, self.linear.in_features)) |
|
if enable_k: |
|
self.lora_ind.extend(range(self.linear.in_features, self.linear.in_features + self.kv_embd_size)) |
|
if enable_v: |
|
self.lora_ind.extend(range(self.linear.in_features + self.kv_embd_size, self.linear.out_features)) |
|
self.reset_parameters() |
|
|
|
def zero_pad(self, x: torch.Tensor) -> torch.Tensor: |
|
"""Properly pad weight updates with zeros. |
|
|
|
If, based on `self.enable_lora`, we want to fine-tune queries and values, but not keys, |
|
        then the weight update should be:
|
|
|
        [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,],
         [....................................],
         [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]]
            ↑              ↑            ↑
        ________________________________________
        | query         | key       | value    |
        ----------------------------------------
|
|
|
Args: |
|
x: tensor with weights update that will be padded with zeros if necessary |
|
|
|
Returns: |
|
A tensor with weight updates and zeros for deselected q, k or v |
|
""" |
|
|
|
if all(self.enable_lora): |
|
return x |
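        # Only some of q/k/v are LoRA-adapted, so the update tensor is narrower than the full QKV output:
        # scatter it into a zero tensor of width `out_features` at the positions recorded in `lora_ind`.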
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
x = x.transpose(0, 1) |
|
result = x.new_zeros((*x.shape[:-1], self.linear.out_features)) |
|
result = result.view(-1, self.linear.out_features) |
|
result = result.index_copy( |
|
1, torch.tensor(self.lora_ind, device=result.device), x.reshape(-1, sum(self.qkv_shapes)) |
|
) |
|
return result.view((*x.shape[:-1], self.linear.out_features)).transpose(0, 1) |
|
|
|
def conv1d(self, input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: |
|
"""An extension of the `torch.nn.functional.conv1d` function with a logic specific to grouped queries. |
|
|
|
        If the number of heads is equal to the number of query groups, grouped queries are disabled
        (see the scheme in `lit_gpt/config.py:Config`). In this case the combined QKV matrix consists of equally sized
        query, key and value parts, which means we can use the `groups` argument of `conv1d`: with this argument the
        input and weight matrices are split into equally sized parts and applied separately (like having multiple
        conv layers side by side).

        Otherwise the QKV matrix consists of unequally sized parts, so we have to split the input and weight matrices
        manually, apply each part of the weight matrix to the corresponding part of the input and concatenate the result.
|
|
|
Args: |
|
input: input matrix of shape (B, C, T) |
|
weight: weight matrix of shape (C_output, rank, 1). |
|
"C_output" is defined as a sum of embedding sizes for each enabled LoRA layer (see init method of the class). |
|
|
|
Returns: |
|
A tensor with a shape (B, C_output, T) |
|
|
|
""" |
|
if self.n_head == self.n_query_groups: |
|
return F.conv1d(input, weight, groups=sum(self.enable_lora)) |
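        # Grouped/multi-query attention: the q, k and v parts have different widths, so `groups` cannot be used.
        # Split the input per enabled projection and the weight by `qkv_shapes`, convolve each pair, then concatenate.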
|
|
|
|
|
|
|
|
|
|
|
|
|
input_splitted = input.chunk(sum(self.enable_lora), dim=1) |
|
weight_splitted = weight.split(self.qkv_shapes) |
|
return torch.cat( |
|
[F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)], dim=1 |
|
) |
|
|
|
def merge(self) -> None: |
|
"""Merges the LoRA weights into the full-rank weights (W = W + delta_W).""" |
|
|
|
|
|
|
|
|
|
|
|
if self.r > 0 and any(self.enable_lora) and not self.merged: |
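            # delta_W is obtained by running lora_A through the same grouped 1x1 convolution with lora_B as the
            # kernel, which computes B @ A for every enabled projection; the result is then scaled and zero-padded
            # to the full weight shape before being added to the frozen weights.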
|
delta_w = self.conv1d( |
|
self.lora_A.data.unsqueeze(0), |
|
self.lora_B.data.unsqueeze(-1), |
|
).squeeze( |
|
0 |
|
) |
|
|
|
self.linear.weight.data += self.zero_pad(delta_w * self.scaling) |
|
self.merged = True |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
"""Do the forward pass. |
|
|
|
        If LoRA's weights are already merged with the pretrained ones, then it's a simple matrix multiplication.
        If not, multiply the input by the pretrained weights, apply LoRA to the input and sum the two results.
|
|
|
Args: |
|
x: input tensor of shape (batch_size, context_length, embedding_size) |
|
|
|
Returns: |
|
Output tensor of shape (batch_size, context_length, 3 * embedding_size) |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pretrained = self.linear(x) |
|
if self.r == 0 or not any(self.enable_lora) or self.merged: |
|
return pretrained |
|
after_A = F.linear(self.lora_dropout(x), self.lora_A) |
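        # F.conv1d expects (B, C, T), so move the channel dimension before the sequence dimension for the grouped
        # convolution with lora_B and move it back afterwards.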
|
|
|
|
|
|
|
after_B = self.conv1d( |
|
after_A.transpose(-2, -1), |
|
self.lora_B.unsqueeze(-1), |
|
).transpose( |
|
-2, -1 |
|
) |
|
lora = self.zero_pad(after_B) * self.scaling |
|
return pretrained + lora |
|
|
|
|
|
def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None: |
|
"""Freeze all modules except LoRA's and depending on 'bias' value unfreezes bias weights. |
|
|
|
Args: |
|
model: model with LoRA layers |
|
bias: |
|
``"none"``: all bias weights will be frozen, |
|
``"lora_only"``: only bias weight for LoRA layers will be unfrozen, |
|
``"all"``: all bias weights will be unfrozen. |
|
|
|
Raises: |
|
NotImplementedError: if `bias` not in ["none", "lora_only", "all"] |
|
""" |
|
|
|
for n, p in model.named_parameters(): |
|
if "lora_" not in n: |
|
p.requires_grad = False |
|
|
|
|
|
if bias == "none": |
|
return |
|
if bias == "all": |
|
for n, p in model.named_parameters(): |
|
if "bias" in n: |
|
p.requires_grad = True |
|
elif bias == "lora_only": |
|
for m in model.modules(): |
|
if isinstance(m, LoRALayer) and hasattr(m, "bias") and m.bias is not None: |
|
m.bias.requires_grad = True |
|
else: |
|
raise NotImplementedError |
|
|
|
|
|
def lora_filter(key: str, value: Any) -> bool: |
|
return "lora_" in key |
|
|
|
|
|
@dataclass |
|
class Config(BaseConfig): |
|
""" |
|
Args: |
|
r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of |
|
the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2) |
|
alpha: alpha is needed for scaling updates as alpha/r |
|
"This scaling helps to reduce the need to retune hyperparameters when we vary r" |
|
https://arxiv.org/pdf/2106.09685.pdf (section 4.1) |
|
dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A) |
|
        to_*: whether to apply LoRA to the specified weights or not
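
    Example (illustrative values; any model name known to lit_gpt works):
        config = Config.from_name("pythia-70m", r=8, alpha=16, dropout=0.05, to_query=True, to_value=True)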
|
""" |
|
|
|
r: int = 0 |
|
alpha: int = 1 |
|
dropout: float = 0.0 |
|
to_query: bool = False |
|
to_key: bool = False |
|
to_value: bool = False |
|
to_projection: bool = False |
|
to_mlp: bool = False |
|
to_head: bool = False |
|
|
|
@property |
|
def mlp_class(self) -> Type: |
|
return getattr(lit_gpt.lora, self._mlp_class) |
|
|
|
|
|
class GPT(BaseModel): |
|
def __init__(self, config: Config) -> None: |
|
nn.Module.__init__(self) |
|
assert config.padded_vocab_size is not None |
|
self.config = config |
|
|
|
self.lm_head = LoRALinear( |
|
config.n_embd, |
|
config.padded_vocab_size, |
|
bias=config.lm_head_bias, |
|
r=(config.r if config.to_head else 0), |
|
lora_alpha=config.alpha, |
|
lora_dropout=config.dropout, |
|
) |
|
self.transformer = nn.ModuleDict( |
|
dict( |
|
wte=nn.Embedding(config.padded_vocab_size, config.n_embd), |
|
h=nn.ModuleList(Block(config) for _ in range(config.n_layer)), |
|
ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), |
|
) |
|
) |
|
self.max_seq_length = self.config.block_size |
|
self.mask_cache: Optional[torch.Tensor] = None |
|
|
|
def forward( |
|
        self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, lm_head_chunk_size: int = 0, maxlen: Optional[int] = None
|
) -> Union[torch.Tensor, List[torch.Tensor]]: |
|
T = idx.size(1) if maxlen is None else maxlen |
|
if self.max_seq_length < T: |
|
raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.") |
|
|
|
if input_pos is not None: |
|
cos = self.cos.index_select(0, input_pos) |
|
sin = self.sin.index_select(0, input_pos) |
|
if self.mask_cache is None: |
|
raise TypeError("You need to call `gpt.set_kv_cache()`") |
|
mask = self.mask_cache.index_select(2, input_pos) |
|
else: |
|
cos = self.cos[:T] |
|
sin = self.sin[:T] |
|
mask = None |
|
|
|
if type(idx) is tuple: |
|
|
|
stack_before_tokens_x, motion_tokens, before_len = idx |
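            # `idx` may be a tuple of (token ids, precomputed motion-token embeddings, per-sample offsets);
            # the motion-token embeddings are spliced into the token embeddings at those offsets below.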
|
|
|
|
|
|
|
|
|
x = self.transformer.wte(stack_before_tokens_x) |
|
|
|
for i in range(len(x)): |
|
x[i][before_len[i]: before_len[i] + len(motion_tokens[i])] = motion_tokens[i] |
|
else: |
|
x = self.transformer.wte(idx) |
|
for block in self.transformer.h: |
|
x = block(x, cos, sin, mask, input_pos) |
|
x = self.transformer.ln_f(x) |
|
if lm_head_chunk_size > 0: |
|
|
|
return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)] |
|
return self.lm_head(x) |
|
|
|
@classmethod |
|
def from_name(cls, name: str, **kwargs: Any) -> Self: |
|
return cls(Config.from_name(name, **kwargs)) |
|
|
|
def _init_weights(self, module: nn.Module) -> None: |
|
"""Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness.""" |
|
super()._init_weights(module) |
|
if isinstance(module, LoRALinear): |
|
module.reset_parameters() |
|
|
|
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: |
|
"""For compatibility with base checkpoints.""" |
|
mapping = {"lm_head.weight": "lm_head.linear.weight"} |
|
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) |
|
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) |
|
|
|
|
|
class Block(BaseBlock): |
|
def __init__(self, config: Config) -> None: |
|
nn.Module.__init__(self) |
|
self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) |
|
self.attn = CausalSelfAttention(config) |
|
if not config.shared_attention_norm: |
|
self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps) |
|
self.mlp = config.mlp_class(config) |
|
|
|
self.config = config |
|
|
|
|
|
class CausalSelfAttention(BaseCausalSelfAttention): |
|
def __init__(self, config: Config) -> None: |
|
|
|
|
|
nn.Module.__init__(self) |
|
shape = (config.n_head + 2 * config.n_query_groups) * config.head_size |
|
|
|
self.attn = LoRAQKVLinear( |
|
in_features=config.n_embd, |
|
out_features=shape, |
|
r=config.r, |
|
lora_alpha=config.alpha, |
|
lora_dropout=config.dropout, |
|
enable_lora=(config.to_query, config.to_key, config.to_value), |
|
bias=config.bias, |
|
|
|
n_head=config.n_head, |
|
n_query_groups=config.n_query_groups, |
|
) |
|
|
|
self.proj = LoRALinear( |
|
config.n_embd, |
|
config.n_embd, |
|
bias=config.bias, |
|
r=(config.r if config.to_projection else 0), |
|
lora_alpha=config.alpha, |
|
lora_dropout=config.dropout, |
|
) |
|
|
|
self.kv_cache: Optional[KVCache] = None |
|
|
|
self.config = config |
|
|
|
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: |
|
"""For compatibility with base checkpoints.""" |
|
mapping = { |
|
"attn.weight": "attn.linear.weight", |
|
"attn.bias": "attn.linear.bias", |
|
"proj.weight": "proj.linear.weight", |
|
"proj.bias": "proj.linear.bias", |
|
} |
|
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) |
|
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) |
|
|
|
|
|
class GptNeoxMLP(lit_gpt.model.GptNeoxMLP): |
|
def __init__(self, config: Config) -> None: |
|
nn.Module.__init__(self) |
|
self.fc = LoRALinear( |
|
config.n_embd, |
|
config.intermediate_size, |
|
bias=config.bias, |
|
r=(config.r if config.to_mlp else 0), |
|
lora_alpha=config.alpha, |
|
lora_dropout=config.dropout, |
|
) |
|
self.proj = LoRALinear( |
|
config.intermediate_size, |
|
config.n_embd, |
|
bias=config.bias, |
|
r=(config.r if config.to_mlp else 0), |
|
lora_alpha=config.alpha, |
|
lora_dropout=config.dropout, |
|
) |
|
|
|
self.config = config |
|
|
|
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: |
|
"""For compatibility with base checkpoints.""" |
|
mapping = { |
|
"fc.weight": "fc.linear.weight", |
|
"fc.bias": "fc.linear.bias", |
|
"proj.weight": "proj.linear.weight", |
|
"proj.bias": "proj.linear.bias", |
|
} |
|
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) |
|
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) |
|
|
|
|
|
class LLaMAMLP(lit_gpt.model.LLaMAMLP): |
|
def __init__(self, config: Config) -> None: |
|
nn.Module.__init__(self) |
|
self.fc_1 = LoRALinear( |
|
config.n_embd, |
|
config.intermediate_size, |
|
bias=config.bias, |
|
r=(config.r if config.to_mlp else 0), |
|
lora_alpha=config.alpha, |
|
lora_dropout=config.dropout, |
|
) |
|
self.fc_2 = LoRALinear( |
|
config.n_embd, |
|
config.intermediate_size, |
|
bias=config.bias, |
|
r=(config.r if config.to_mlp else 0), |
|
lora_alpha=config.alpha, |
|
lora_dropout=config.dropout, |
|
) |
|
self.proj = LoRALinear( |
|
config.intermediate_size, |
|
config.n_embd, |
|
bias=config.bias, |
|
r=(config.r if config.to_mlp else 0), |
|
lora_alpha=config.alpha, |
|
lora_dropout=config.dropout, |
|
) |
|
|
|
def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: |
|
"""For compatibility with base checkpoints.""" |
|
mapping = { |
|
"fc_1.weight": "fc_1.linear.weight", |
|
"fc_1.bias": "fc_1.linear.bias", |
|
"fc_2.weight": "fc_2.linear.weight", |
|
"fc_2.bias": "fc_2.linear.bias", |
|
"proj.weight": "proj.linear.weight", |
|
"proj.bias": "proj.linear.bias", |
|
} |
|
state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) |
|
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) |
|
|
|
|
|
def merge_lora_weights(model: GPT) -> None: |
|
"""Merge LoRA weights into the full-rank weights to speed up inference.""" |
|
for module in model.modules(): |
|
if isinstance(module, LoRALinear): |
|
module.merge() |
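

# A minimal fine-tuning flow sketch (illustrative; assumes a lit_gpt config name and a training loop exist):
#
#     config = Config.from_name("pythia-70m", r=8, alpha=16, to_query=True, to_value=True)
#     model = GPT(config)
#     mark_only_lora_as_trainable(model)
#     ... train; save only the LoRA weights, e.g. by filtering the state dict with `lora_filter` ...
#     merge_lora_weights(model)  # fold the low-rank updates into the frozen weights for inference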
|
|