|
|
|
|
|
from typing import Optional, Sequence, Tuple |
|
|
|
import numpy as np |
|
import torch as T |
|
from opt_einsum import contract_expression |
|
from opt_einsum.contract import ContractExpression |
|
|
|
from .linalg import ttd |
|
|
|
|
|
def make_contraction(shape, rank, batch_size=32,
                     seqlen=512) -> ContractExpression:
    """Build a pre-optimized einsum expression applying a TT-matrix to a
    flattened batch of input vectors.

    :param shape: Pair of mode factorizations ``(row_shape, col_shape)``,
        one entry per TT-core.
    :param rank: TT-ranks including the boundary ones, ``len(shape[0]) + 1``
        values in total.
    :param batch_size: Nominal batch size used to size the input operand of
        the expression (the contraction path is optimized for it).
    :param seqlen: Nominal sequence length; the input operand's leading axis
        is ``batch_size * seqlen``.
    """
    ndim = len(rank) - 1
    row_shape, col_shape = shape

    # Assign a distinct integer symbol to every contraction axis: row modes,
    # column modes, bond (rank) axes, and a separate batch axis.
    row_ix = tuple(range(ndim))
    col_ix = tuple(range(ndim, 2 * ndim))
    rank_ix = tuple(range(2 * ndim, 3 * ndim + 1))
    batch_ix = 4 * ndim

    # Input operand: flattened batch followed by the row multi-index.
    operands = [(batch_size * seqlen, ) + tuple(row_shape),
                (batch_ix, ) + row_ix]

    # One operand per TT-core, indexed (rank_k, row_k, col_k, rank_{k+1}).
    for k in range(ndim):
        operands.append((rank[k], row_shape[k], col_shape[k], rank[k + 1]))
        operands.append((rank_ix[k], row_ix[k], col_ix[k], rank_ix[k + 1]))

    # Output keeps the batch axis, all column axes, and both boundary rank
    # axes (each of size one for a proper TT-matrix).
    operands.append((batch_ix, ) + col_ix + (rank_ix[0], rank_ix[-1]))

    return contract_expression(*operands)
|
|
|
|
|
class TTCompressedLinear(T.nn.Module):
    """Class TTCompressedLinear is a layer which represents a weight matrix of
    linear layer in factorized view as tensor train matrix.

    >>> linear_layer = T.nn.Linear(6, 6)
    >>> tt_layer = TTCompressedLinear \
    ...     .from_linear(linear_layer, rank=2, shape=((2, 3), (3, 2)))
    """

    def __init__(self, cores: Sequence[T.Tensor],
                 bias: Optional[T.Tensor] = None):
        """Initialize the layer from ready-made TT-cores.

        :param cores: TT-cores, each of shape (rank_k, row_k, col_k,
            rank_{k+1}).
        :param bias: Optional bias vector with ``prod(col_k)`` elements.
        :raises ValueError: If a core is not 4-dimensional or the bias size
            does not match the number of output features.
        """
        super().__init__()

        for i, core in enumerate(cores):
            if core.ndim != 4:
                # BUG FIX: report the offending core's ndim; the former
                # `cores.ndim` raised AttributeError on a plain sequence
                # instead of the intended ValueError.
                raise ValueError('Expected number of dimensions of the '
                                 f'{i}-th core is 4 but given {core.ndim}.')

        # TT-ranks: the leading rank of a TT-matrix is one; the remaining
        # ranks are read off the last axis of each core.
        self.rank = (1, ) + tuple(core.shape[3] for core in cores)
        self.shape = (tuple(core.shape[1] for core in cores),
                      tuple(core.shape[2] for core in cores))
        # NOTE(review): attribute name `contact` (sic, presumably meant
        # "contract") is kept for backward compatibility with callers.
        self.contact = make_contraction(self.shape, self.rank)

        # Cast to plain ints: np.prod returns np.int64, which makes the
        # bias-size comparison below and downstream reshape() calls fragile.
        self.in_features = int(np.prod(self.shape[0]))
        self.out_features = int(np.prod(self.shape[1]))

        self.cores = T.nn.ParameterList(T.nn.Parameter(core) for core in cores)
        self.bias = None
        if bias is not None:
            # BUG FIX: compare element counts. `bias.size()` is a torch.Size
            # (a tuple) and never equals a plain integer; the old check only
            # worked by accident through numpy scalar broadcasting.
            if bias.numel() != self.out_features:
                raise ValueError(f'Expected bias size is {self.out_features} '
                                 f'but its shape is {bias.shape}.')
            self.bias = T.nn.Parameter(bias)

    def forward(self, input: T.Tensor) -> T.Tensor:
        """Apply the TT-factorized linear map (plus bias) to `input`.

        The last axis of `input` must have `in_features` elements; all
        leading axes are treated as batch dimensions and preserved.
        """
        # Flatten all leading axes into a single batch axis and split the
        # feature axis into the TT row multi-index.
        input_shape = input.shape
        input = input.reshape(-1, *self.shape[0])

        # Contract the input against all TT-cores in one optimized einsum,
        # then restore the original leading axes with features last.
        output = self.contact(input, *self.cores)
        output = output.reshape(*input_shape[:-1], self.out_features)

        if self.bias is not None:
            output += self.bias
        return output

    @classmethod
    def from_linear(cls, linear: T.nn.Linear,
                    shape: Tuple[Tuple[int], Tuple[int]], rank: int, **kwargs):
        """Build a TT-compressed layer from an existing `torch.nn.Linear`.

        :param linear: Source layer whose weight matrix is decomposed.
        :param shape: Pair of factorizations (row shape, column shape) whose
            products equal the layer's in/out features respectively.
        :param rank: Internal TT-rank shared by all bonds.
        :param kwargs: Extra options forwarded to `ttd`.
        """
        ndim = len(shape[0])

        # Boundary ranks of a TT-matrix are always one.
        tt_rank = (1, ) + (rank, ) * (ndim - 1) + (1, )
        tt_shape = tuple(n * m for n, m in zip(*shape))

        # View the transposed weight as a 2*ndim-way tensor and interleave
        # row/column axes into (n_1, m_1, n_2, m_2, ...).
        matrix = linear.weight.data.T
        tensor = matrix.reshape(shape[0] + shape[1])
        for i in range(ndim - 1):
            tensor = tensor.moveaxis(ndim + i, 2 * i + 1)

        # Fuse each (n_k, m_k) pair into one mode and run the decomposition.
        tensor = tensor.reshape(tt_shape)
        cores = ttd(tensor, tt_rank, **kwargs)

        # Split each fused mode back into separate row/column axes.
        core_shapes = zip(tt_rank, *shape, tt_rank[1:])
        cores = [core.reshape(core_shape)
                 for core, core_shape in zip(cores, core_shapes)]

        bias = None
        if linear.bias is not None:
            bias = T.clone(linear.bias.data)

        # Use `cls` (not the class name) so subclasses construct themselves.
        return cls(cores, bias)

    @classmethod
    def from_random(cls, shape: Tuple[Tuple[int], Tuple[int]], rank: int,
                    bias: bool = True):
        """Build a layer with randomly initialized (standard normal) cores.

        :param shape: Pair of factorizations (row shape, column shape).
        :param rank: Internal TT-rank shared by all bonds.
        :param bias: Whether to create a random bias term as well.
        """
        tt_ndim = len(shape[0])
        tt_rank = (1, ) + (rank, ) * (tt_ndim - 1) + (1, )
        core_shapes = zip(tt_rank, *shape, tt_rank[1:])
        cores = [T.randn(core_shape) for core_shape in core_shapes]

        bias_term = None
        if bias:
            # Plain int: np.prod yields np.int64.
            out_features = int(np.prod(shape[1]))
            bias_term = T.randn(out_features)

        # Use `cls` so subclasses construct themselves.
        return cls(cores, bias_term)
|
|