from typing import Optional, Tuple

import torch as T


class SVDCompressedLinearFunc(T.autograd.Function):
    """Autograd function computing `input @ lhs @ rhs + bias` with
    hand-written gradients for the low-rank weight factors."""

    @staticmethod
    def forward(ctx, input: T.Tensor, lhs: T.Tensor,
                rhs: T.Tensor, bias: Optional[T.Tensor] = None) -> T.Tensor:
        # The weight is kept as two low-rank factors, so the layer is applied
        # as two skinny matmuls instead of one dense one.
        output = (input @ lhs) @ rhs
        if bias is not None:
            output += bias[None, :]
        ctx.bias = bias is not None
        ctx.save_for_backward(input, lhs, rhs)
        return output

    @staticmethod
    def backward(ctx, grad_output: T.Tensor):
        input, lhs, rhs = ctx.saved_tensors

        # Flatten any leading batch dimensions so that every product below
        # is a plain 2D matrix multiplication.
        inp_size = lhs.shape[0]
        out_size = rhs.shape[1]
        input_shape = input.shape
        input = input.reshape(-1, inp_size)
        grad_output = grad_output.reshape(-1, out_size)

        input_grad = None
        if ctx.needs_input_grad[0]:
            input_grad = (grad_output @ rhs.T) @ lhs.T

        lhs_grad = None
        if ctx.needs_input_grad[1]:
            lhs_grad = input.T @ (grad_output @ rhs.T)

        rhs_grad = None
        if ctx.needs_input_grad[2]:
            rhs_grad = (input @ lhs).T @ grad_output

        bias_grad = None
        if ctx.needs_input_grad[3]:
            bias_grad = grad_output.sum(dim=0)

        # Restore the batch shape of the input gradient if it was computed.
        if input_grad is not None:
            input_grad = input_grad.reshape(input_shape)
        return input_grad, lhs_grad, rhs_grad, bias_grad


compressed_linear_svd = SVDCompressedLinearFunc.apply
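# A quick sanity check of the hand-written backward pass (a sketch, not part
# of the original module): torch.autograd.gradcheck compares it against
# numerical gradients and expects double-precision inputs; the shapes below
# are arbitrary.
#
#     x = T.randn(3, 10, dtype=T.double, requires_grad=True)
#     lhs = T.randn(10, 4, dtype=T.double, requires_grad=True)
#     rhs = T.randn(4, 20, dtype=T.double, requires_grad=True)
#     bias = T.randn(20, dtype=T.double, requires_grad=True)
#     T.autograd.gradcheck(compressed_linear_svd, (x, lhs, rhs, bias))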
class SVDCompressedLinear(T.nn.Module):
    """Class SVDCompressedLinear is a layer which represents the weight matrix
    of a linear layer in factorized (truncated SVD) form.

    >>> linear_layer = T.nn.Linear(10, 20)
    >>> svd_layer = SVDCompressedLinear.from_linear(linear_layer, rank=5)
    """

    def __init__(self, factors: Tuple[T.Tensor, T.Tensor, T.Tensor],
                 bias: Optional[T.Tensor] = None):
        super().__init__()

        # Factors are (U, S, V^T). The singular values are split evenly
        # between the two trainable factors so that lhs @ rhs == V S U^T,
        # i.e. the transposed weight matrix applied in forward().
        scale = T.sqrt(factors[1])

        self.lhs = T.nn.Parameter(factors[2].T * scale[None, :])
        self.rhs = T.nn.Parameter(factors[0].T * scale[:, None])

        self.bias = None
        if bias is not None:
            self.bias = T.nn.Parameter(bias)

        self.in_features = self.lhs.shape[0]
        self.out_features = self.rhs.shape[1]

    @classmethod
    def from_linear(cls, linear: T.nn.Linear, rank: Optional[int] = None,
                    tol: float = 1e-6):
        with T.no_grad():
            data = linear.weight.data
            lhs, vals, rhs = T.linalg.svd(data)
            if rank is None:
                # Automatic rank selection (e.g. from `tol`) is not
                # implemented yet; an explicit rank is required.
                raise NotImplementedError
            else:
                lhs = lhs[:, :rank]
                rhs = rhs[:rank, :]
                vals = vals[:rank]

            bias = None
            if linear.bias is not None:
                bias = T.clone(linear.bias.data)

        return SVDCompressedLinear((lhs, vals, rhs), bias)

    @classmethod
    def from_random(cls, in_features: int, out_features: int, rank: int,
                    bias: bool = True):
        lvecs = T.randn((out_features, rank))
        svals = T.ones(rank)
        rvecs = T.randn((rank, in_features))
        bias_term = None
        if bias:
            bias_term = T.randn(out_features)
        return SVDCompressedLinear((lvecs, svals, rvecs), bias_term)

    def forward(self, input: T.Tensor) -> T.Tensor:
        return compressed_linear_svd(input, self.lhs, self.rhs, self.bias)
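# A minimal self-check sketch (not part of the original module): compress a
# randomly initialized linear layer at full rank, where the truncated SVD
# reproduces the weight exactly up to floating-point error, and compare the
# outputs of the two layers.
if __name__ == "__main__":
    T.manual_seed(0)
    linear_layer = T.nn.Linear(10, 20)
    svd_layer = SVDCompressedLinear.from_linear(linear_layer, rank=10)
    inputs = T.randn(3, 10)
    with T.no_grad():
        error = (svd_layer(inputs) - linear_layer(inputs)).abs().max().item()
    print(f"max abs error at full rank: {error:.2e}")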