import torch
import torch.nn as nn
from typing import Optional
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
def inverse_2x2(matrices):
# Extract matrix elements
# matrices[..., 0, 0] corresponds to 'a' in [[a, b], [c, d]]
a = matrices[..., 0, 0]
b = matrices[..., 0, 1]
c = matrices[..., 1, 0]
d = matrices[..., 1, 1]
# Compute determinant
det = a * d - b * c
# Compute inverse using the formula:
# inv = (1/det) * [[d, -b], [-c, a]]
inv_det = 1.0 / det
# Create output tensor
inv_matrices = torch.empty_like(matrices)
inv_matrices[..., 0, 0] = d * inv_det
inv_matrices[..., 0, 1] = -b * inv_det
inv_matrices[..., 1, 0] = -c * inv_det
inv_matrices[..., 1, 1] = a * inv_det
return inv_matrices
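# Sanity-check sketch (not part of the original module): `_check_inverse_2x2` is a
# hypothetical helper comparing the closed-form batched inverse above against
# torch.linalg.inv on well-conditioned random 2x2 matrices. It is never called by
# the library code; it only illustrates what the analytic formula is expected to do.
def _check_inverse_2x2(num: int = 8, atol: float = 1e-5) -> bool:
    torch.manual_seed(0)
    # Well-conditioned test matrices: small perturbations of the identity.
    m = torch.eye(2) + 0.1 * torch.randn(num, 2, 2)
    return torch.allclose(inverse_2x2(m), torch.linalg.inv(m), atol=atol)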
class Rotation(nn.Module):
"""
Rotation layer based on Cayley transformation for parameter-efficient fine-tuning.
    This layer implements orthogonal fine-tuning via the Cayley transform
        h(x) = (I - A)^{-1} (I + A) x,
    where A = X Y^T with X = [U; -V] and Y = [V; U]. A is skew-symmetric by
    construction, so (I - A)^{-1} (I + A) is an orthogonal matrix.
    """
def __init__(self, r, dim, T=1.0, num_rotations=4):
super().__init__()
self.r = r
self.T = T
        # U gets a small random init; V starts at zero, so A = 0 and the transform
        # is exactly the identity at initialisation.
        self.U = nn.Parameter(torch.randn(num_rotations, r, dim) * 0.002, requires_grad=True)
        self.V = nn.Parameter(torch.zeros(num_rotations, r, dim), requires_grad=True)
self.num_rotations = num_rotations
def forward(self, x):
"""
Apply Cayley transformation to input x.
A = XY^T where X = [U; -V], Y = [V; U]
Cayley transformation: h(x) = (I - A)^{-1} (I + A) x
Uses Woodbury identity for efficient computation:
(I - XY^T)^{-1} = I + X (I - Y^T X)^{-1} Y^T
Args:
x: Input tensor of shape (..., dim)
Returns:
Transformed tensor of shape (..., dim)
"""
x_dtype = x.dtype
X = torch.cat([self.U, -self.V], dim=1) # Shape: (num_rotations, 2r, dim)
Y = torch.cat([self.V, self.U], dim=1) * self.T # Shape: (num_rotations, 2r, dim)
Y_T_X = torch.matmul(Y, X.transpose(1, 2)) # Shape: (num_rotations, 2r, 2r)
I_2r = torch.eye(2 * self.r, device=x.device, dtype=x.dtype).repeat(self.num_rotations, 1, 1)
I_minus_YX = I_2r - Y_T_X
if self.r == 1:
I_minus_YX_inv = inverse_2x2(I_minus_YX)
else:
            # Invert in float32 for numerical stability, then cast back to the input dtype.
I_minus_YX = I_minus_YX.to(torch.float32)
I_minus_YX_inv = torch.linalg.inv(I_minus_YX) # Shape: (num_rotations, 2r, 2r)
I_minus_YX_inv = I_minus_YX_inv.to(x_dtype)
        Yx = torch.einsum("...d,nrd->...nr", x, Y)  # Shape: (..., num_rotations, 2r)
        # Full (2r x 2r) matrix-vector product per rotation block; distinct subscripts
        # R and r are required here, since "nrr" would only select the diagonal of the inverse.
        I_minus_YX_inv_Yx = torch.einsum("nRr,...nr->...nR", I_minus_YX_inv, Yx)
        second_term = torch.einsum("...nR,nRd->...nd", I_minus_YX_inv_Yx, X)  # Shape: (..., num_rotations, dim)
        second_term = second_term.sum(dim=-2)  # Sum over the rotation blocks
        output = x + 2 * second_term  # Shape: (..., dim)
return output
def get_delta_weight(self):
"""
Compute the delta weight matrix induced by the rotation layer.
Returns:
Delta weight matrix of shape (dim, dim)
"""
X = torch.cat([self.U, -self.V], dim=1) # Shape: (num_rotations, 2r, dim)
Y = torch.cat([self.V, self.U], dim=1) * self.T # Shape: (num_rotations, 2r, dim)
Y_T_X = torch.matmul(Y, X.transpose(1, 2)) # Shape: (num_rotations, 2r, 2r)
I_2r = torch.eye(2 * self.r, device=X.device, dtype=X.dtype).repeat(self.num_rotations, 1, 1)
I_minus_YX = I_2r - Y_T_X
if self.r == 1:
I_minus_YX_inv = inverse_2x2(I_minus_YX)
I_minus_YX_inv_Y = torch.einsum("nRr,nrd->nRd", I_minus_YX_inv, Y) # Shape: (num_rotations, 2r, dim)
else:
I_minus_YX_inv_Y = torch.linalg.solve(I_minus_YX.to(torch.float32), Y.to(torch.float32)) # Shape: (num_rotations, 2r, dim)
I_minus_YX_inv_Y = I_minus_YX_inv_Y.to(X.dtype)
second_term = torch.einsum("nrd,nrD->ndD", X, I_minus_YX_inv_Y) # Shape: (num_rotations, dim, dim)
second_term = second_term.sum(dim=0)
total_delta_weight = 2 * second_term
return total_delta_weight
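# Consistency sketch (not part of the original module): `_check_rotation_consistency`
# is a hypothetical helper illustrating two properties the class is built around:
# Rotation.forward(x) should equal x @ (I + Delta)^T with Delta = get_delta_weight(),
# and, because A is skew-symmetric, the Cayley rotation of a single block should be
# numerically orthogonal. It is never called by the library code.
def _check_rotation_consistency(dim: int = 16, r: int = 2, atol: float = 1e-4) -> None:
    torch.manual_seed(0)
    # A single rotation block keeps both checks exact; with several blocks the layer
    # sums their contributions, so only the forward/delta consistency check applies.
    rot = Rotation(r=r, dim=dim, T=1.0, num_rotations=1)
    with torch.no_grad():
        # V is zero-initialised, which would make the test vacuous; perturb it slightly.
        rot.V.copy_(torch.randn_like(rot.V) * 0.01)
    x = torch.randn(4, dim)
    R = torch.eye(dim) + rot.get_delta_weight()
    assert torch.allclose(rot(x), x @ R.T, atol=atol)  # forward matches the merged rotation
    assert torch.allclose(R @ R.T, torch.eye(dim), atol=atol)  # Cayley map of a skew A is orthogonal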
class RotationLayer(BaseTunerLayer):
"""
Adapter-like wrapper that attaches Rotation modules to a base linear layer.
"""
adapter_layer_names: tuple[str, ...] = ("rotation",)
other_param_names: tuple[str, ...] = ("r", "T", "num_rotations", "scaling")
def __init__(self, base_layer: nn.Module, **kwargs):
        # Cooperative __init__ along the MRO; nn.Module.__init__ is expected to be
        # called by the concrete subclass (e.g. Linear below) before this runs.
        super().__init__()
        # Store the base layer and the per-adapter containers.
        self.base_layer = base_layer
        self.rotation = nn.ModuleDict()  # adapter_name -> Rotation module
        self.scaling = {}  # default scaling per adapter
        self._adapter_config = {}  # r, T, num_rotations per adapter
# flags (exposed in a simple way)
self._disable_adapters = False
self.merged_adapters: list[str] = []
self._cast_input_dtype_enabled = True
self.kwargs = kwargs
if isinstance(base_layer, nn.Linear):
self.in_features = base_layer.in_features
self.out_features = base_layer.out_features
else:
raise NotImplementedError("RotationLayer only supports nn.Linear base layers for now.")
@property
def _available_adapters(self) -> set[str]:
return set(self.rotation.keys())
@property
def disable_adapters(self) -> bool:
return self._disable_adapters
@property
def merged(self) -> bool:
return bool(self.merged_adapters)
@property
def active_adapters(self) -> list[str]:
# If some external mechanism sets active adapters, prefer it; else use all added adapters.
return getattr(self, "_active_adapters", list(self.rotation.keys()))
def get_base_layer(self) -> nn.Module:
return self.base_layer
def _cast_input_dtype(self, x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
if not self._cast_input_dtype_enabled:
return x
return x.to(dtype)
def update_layer(
self,
adapter_name: str,
r: int,
T: float,
num_rotations: int,
**kwargs,
):
"""
Add / update a rotation adapter for this layer.
"""
if r <= 0:
raise ValueError(f"r must be positive, got {r}")
if num_rotations <= 0:
raise ValueError(f"num_rotations must be positive, got {num_rotations}")
rot = Rotation(r=r, dim=self.in_features, T=T, num_rotations=num_rotations)
self.rotation[adapter_name] = rot
self.scaling[adapter_name] = 1.0
self._adapter_config[adapter_name] = {"r": r, "T": T, "num_rotations": num_rotations}
# (optional) helper to set currently active adapters externally
def set_active_adapters(self, adapters: Optional[list[str]]):
if adapters is None:
if hasattr(self, "_active_adapters"):
delattr(self, "_active_adapters")
else:
self._active_adapters = adapters
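# Usage sketch (not part of the original module): `_demo_adapter_bookkeeping` is a
# hypothetical helper showing how `update_layer` registers additional adapters on a
# wrapped layer and how `set_active_adapters` narrows what `active_adapters` reports.
# It relies on the `Linear` wrapper defined further below; nothing runs at import
# time, so the forward reference is harmless. Adapter names and sizes are illustrative.
def _demo_adapter_bookkeeping() -> None:
    base = nn.Linear(32, 32)
    wrapped = Linear(base, adapter_name="default", r=2, T=1.0, num_rotations=2)
    # Register a second, larger adapter on the same base layer.
    wrapped.update_layer(adapter_name="large", r=4, T=0.5, num_rotations=4)
    assert wrapped._available_adapters == {"default", "large"}
    # Restrict the adapters that forward() will apply, then reset to "all added".
    wrapped.set_active_adapters(["default"])
    assert wrapped.active_adapters == ["default"]
    wrapped.set_active_adapters(None)
    assert set(wrapped.active_adapters) == {"default", "large"}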
class Linear(nn.Module, RotationLayer):
"""
A linear layer with an integrated rotation layer for parameter-efficient fine-tuning.
"""
def __init__(self,
base_layer: nn.Linear,
adapter_name: str,
r: int,
T: float,
num_rotations: int,
**kwargs):
super().__init__()
RotationLayer.__init__(self, base_layer=base_layer, **kwargs)
self._active_adapter = adapter_name
self.update_layer(
adapter_name=adapter_name,
r=r,
T=T,
num_rotations=num_rotations,
**kwargs,
)
    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None):
"""
Merge the adapter effect into the base layer weights:
W_merged = W @ R
where R = I + delta (delta returned by get_delta_weight()).
"""
adapter_names = check_adapters_to_merge(self, adapter_names)
if not adapter_names:
return
base_layer = self.get_base_layer()
orig_dtype = base_layer.weight.dtype
# base_layer.weight shape: (out_features, in_features)
W = base_layer.weight.data # (out, in)
for active_adapter in adapter_names:
if active_adapter not in self._available_adapters:
continue
delta_R = self.rotation[active_adapter].get_delta_weight() # (in, in)
R = torch.eye(delta_R.size(0), device=delta_R.device, dtype=delta_R.dtype) + delta_R # (in, in)
# merged W = W @ R
merged_W = W.to(R.dtype) @ R
if safe_merge and not torch.isfinite(merged_W).all():
raise ValueError("Merging resulted in non-finite weights. Aborting merge.")
base_layer.weight.data = merged_W.contiguous().to(orig_dtype)
# mark merged (so unmerge can restore by inverse)
self.merged_adapters.append(active_adapter)
def unmerge(self):
"""
Reverse merges in LIFO order (pop merged adapters and invert R).
"""
base_layer = self.get_base_layer()
orig_dtype = base_layer.weight.dtype
while self.merged_adapters:
active_adapter = self.merged_adapters.pop()
if active_adapter not in self._available_adapters:
continue
delta_R = self.rotation[active_adapter].get_delta_weight() # (in, in)
R = torch.eye(delta_R.size(0), device=delta_R.device, dtype=delta_R.dtype) + delta_R
R_inv = torch.linalg.inv(R)
merged_W = base_layer.weight.data.to(R.dtype)
unmerged_W = merged_W @ R_inv
base_layer.weight.data = unmerged_W.contiguous().to(orig_dtype)
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
x_dtype = x.dtype
base_layer = self.get_base_layer()
if self.disable_adapters:
# if merged, unmerge to ensure base_layer produces original behavior
if self.merged:
self.unmerge()
return base_layer(x, *args, **kwargs).to(x_dtype)
if self.merged:
# if merged into base layer, just forward
return base_layer(x, *args, **kwargs).to(x_dtype)
# otherwise apply active adapters (transform inputs) then call base layer
for active_adapter in self.active_adapters:
if active_adapter not in self.rotation:
continue
rotation = self.rotation[active_adapter]
x = self._cast_input_dtype(x, rotation.U.dtype)
x = rotation(x)
return base_layer(x, *args, **kwargs).to(x_dtype)
def __repr__(self):
return f"rotation.{super().__repr__()}"