# bart-base-detox-ttd / modules.py
# Copied from rut5compressed/nn/modules.py of the original repository.
from typing import Optional, Sequence, Tuple
import numpy as np
import torch as T
from opt_einsum import contract_expression
from opt_einsum.contract import ContractExpression
from .linalg import ttd


def make_contraction(shape, rank, batch_size=32,
                     seqlen=512) -> ContractExpression:
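    """Builds a reusable opt_einsum contraction that applies a TT-matrix to a
    batch of flattened inputs. Here ``shape`` is a pair of mode tuples (row
    modes, column modes) and ``rank`` is the full TT-rank sequence including
    the boundary ranks, e.g. (1, r, ..., r, 1). The ``batch_size`` and
    ``seqlen`` arguments only seed the compiled operand shapes: in practice
    opt_einsum validates the number and rank of the operands rather than
    their exact sizes, so the expression can be reused with a different
    leading dimension.
    """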
ndim = len(rank) - 1
row_shape, col_shape = shape
# Generate all contraction indexes.
row_ix, col_ix = np.arange(2 * ndim).reshape(2, ndim)
rank_ix = 2 * ndim + np.arange(ndim + 1)
batch_ix = 4 * ndim # Zero-based index.
# Order indexes of cores.
cores_ix = np.column_stack([rank_ix[:-1], row_ix, col_ix, rank_ix[1:]])
cores_shape = zip(rank[:-1], row_shape, col_shape, rank[1:])
# Order indexes of input (contraction by columns: X G_1 G_2 ... G_d).
input_ix = np.insert(row_ix, 0, batch_ix)
input_shape = (batch_size * seqlen, ) + row_shape
# Order indexes of output (append rank indexes as well).
output_ix = np.insert(col_ix, 0, batch_ix)
output_ix = np.append(output_ix, (rank_ix[0], rank_ix[-1]))
# Prepare contraction operands.
ops = [input_shape, input_ix]
for core_ix, core_shape in zip(cores_ix, cores_shape):
ops.append(core_shape)
ops.append(core_ix)
ops.append(output_ix)
ops = [tuple(op) for op in ops]
return contract_expression(*ops)
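

# A quick sanity check of the contraction above (a sketch with illustrative
# shapes; the leading dimension at call time may differ from the compiled
# batch_size * seqlen):
#
#     expr = make_contraction(shape=((2, 3), (3, 2)), rank=(1, 2, 1))
#     out = expr(T.randn(64, 2, 3), T.randn(1, 2, 3, 2), T.randn(2, 3, 2, 1))
#     assert out.shape == (64, 3, 2, 1, 1)  # output modes + boundary ranks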


class TTCompressedLinear(T.nn.Module):
    """TTCompressedLinear is a layer that represents the weight matrix of a
    linear layer in factorized form as a tensor-train (TT) matrix.

    >>> linear_layer = T.nn.Linear(6, 6)
    >>> tt_layer = TTCompressedLinear \
    ...     .from_linear(linear_layer, rank=2, shape=((2, 3), (3, 2)))
    """
def __init__(self, cores: Sequence[T.Tensor],
bias: Optional[T.Tensor] = None):
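        """Assembles the layer from TT-matrix cores and an optional bias.

        Each core is expected to be a 4D tensor with shape
        (r_{k-1}, n_k, m_k, r_k), where the n_k are input modes, the m_k are
        output modes, and the r_k are TT-ranks with r_0 = r_d = 1.
        """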
super().__init__()
        for i, core in enumerate(cores):
            if core.ndim != 4:
                raise ValueError('Expected number of dimensions of the '
                                 f'{i}-th core is 4 but given {core.ndim}.')
        # Prepare contraction expression.
        self.rank = (1, ) + tuple(core.shape[3] for core in cores)
        self.shape = (tuple(core.shape[1] for core in cores),
                      tuple(core.shape[2] for core in cores))
        self.contraction = make_contraction(self.shape, self.rank)
        # The TT-matrix is applied on the left, so this determines the number
        # of input and output features.
self.in_features = np.prod(self.shape[0])
self.out_features = np.prod(self.shape[1])
# Create trainable variables.
self.cores = T.nn.ParameterList(T.nn.Parameter(core) for core in cores)
self.bias = None
if bias is not None:
            if bias.numel() != self.out_features:
raise ValueError(f'Expected bias size is {self.out_features} '
f'but its shape is {bias.shape}.')
self.bias = T.nn.Parameter(bias)

    def forward(self, input: T.Tensor) -> T.Tensor:
        # Replace the feature dimension with mode dimensions in order to
        # contract the input with the TT-matrix.
input_shape = input.shape
input = input.reshape(-1, *self.shape[0])
        # Contract the input with the TT-cores, then collapse the mode
        # dimensions back into a single feature dimension.
        output = self.contraction(input, *self.cores)
output = output.reshape(*input_shape[:-1], self.out_features)
if self.bias is not None:
output += self.bias
return output

    @classmethod
    def from_linear(cls, linear: T.nn.Linear,
                    shape: Tuple[Tuple[int, ...], Tuple[int, ...]],
                    rank: int, **kwargs):
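        """Builds a TT-compressed layer from an existing linear layer by
        decomposing its weight matrix into a tensor train with the ``ttd``
        helper. ``shape`` factorizes (in_features, out_features) into row and
        column modes whose products must match the original dimensions, and
        ``rank`` bounds every internal TT-rank; extra keyword arguments are
        forwarded to ``ttd``.
        """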
ndim = len(shape[0])
# Prepare information about shape and rank of TT (not TTM).
tt_rank = (1, ) + (rank, ) * (ndim - 1) + (1, )
tt_shape = tuple(n * m for n, m in zip(*shape))
        # Reshape the weight matrix into a tensor and interleave its row and
        # column modes, as in a TT-matrix.
matrix = linear.weight.data.T
tensor = matrix.reshape(shape[0] + shape[1])
for i in range(ndim - 1):
tensor = tensor.moveaxis(ndim + i, 2 * i + 1)
# Reshape TT-matrix to a plain TT and apply decomposition.
tensor = tensor.reshape(tt_shape)
cores = ttd(tensor, tt_rank, **kwargs)
# Reshape TT-cores back to TT-matrix cores (TTM-cores).
core_shapes = zip(tt_rank, *shape, tt_rank[1:])
cores = [core.reshape(core_shape)
for core, core_shape in zip(cores, core_shapes)]
# Make copy of bias if it exists.
bias = None
if linear.bias is not None:
bias = T.clone(linear.bias.data)
        return cls(cores, bias)
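
    # Illustrative compression arithmetic (numbers are not from the original
    # source): for modes ((n_1, ..., n_d), (m_1, ..., m_d)) the cores store
    # sum_k r_{k-1} * n_k * m_k * r_k parameters instead of the dense
    # prod_k n_k * prod_k m_k. E.g. shape ((32, 24), (24, 32)) with rank 128
    # gives 1 * 32 * 24 * 128 + 128 * 24 * 32 * 1 = 196_608 weights versus
    # 768 * 768 = 589_824.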

    @classmethod
    def from_random(cls, shape: Tuple[Tuple[int, ...], Tuple[int, ...]],
                    rank: int, bias: bool = True):
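        """Builds a layer of the given TT-matrix shape with cores and bias
        drawn from a standard normal distribution, e.g. for training from
        scratch.
        """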
tt_ndim = len(shape[0])
tt_rank = (1, ) + (rank, ) * (tt_ndim - 1) + (1, )
core_shapes = zip(tt_rank, *shape, tt_rank[1:])
cores = [T.randn(core_shape) for core_shape in core_shapes]
bias_term = None
if bias:
out_features = np.prod(shape[1])
bias_term = T.randn(out_features)
        return cls(cores, bias_term)
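

# Usage sketch (illustrative; assumes the package provides ``.linalg.ttd``;
# the commit uses rank 128, while the 32 * 24 mode split of BART-base's 768
# features is an assumption made for this example):
#
#     layer = T.nn.Linear(768, 768)
#     tt_layer = TTCompressedLinear.from_linear(
#         layer, shape=((32, 24), (24, 32)), rank=128)
#     x = T.randn(4, 16, 768)
#     assert tt_layer(x).shape == (4, 16, 768)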