# Copied from rut5compressed/nn/modules.py of the original repository.
from typing import Optional, Sequence, Tuple

import numpy as np
import torch as T
from opt_einsum import contract_expression
from opt_einsum.contract import ContractExpression

from .linalg import ttd


def make_contraction(shape, rank, batch_size=32,
                     seqlen=512) -> ContractExpression:
    ndim = len(rank) - 1
    row_shape, col_shape = shape

    # Generate all contraction indexes.
    row_ix, col_ix = np.arange(2 * ndim).reshape(2, ndim)
    rank_ix = 2 * ndim + np.arange(ndim + 1)
    batch_ix = 4 * ndim  # Zero-based index.

    # Order indexes of cores.
    cores_ix = np.column_stack([rank_ix[:-1], row_ix, col_ix, rank_ix[1:]])
    cores_shape = zip(rank[:-1], row_shape, col_shape, rank[1:])

    # Order indexes of input (contraction by columns: X G_1 G_2 ... G_d).
    input_ix = np.insert(row_ix, 0, batch_ix)
    input_shape = (batch_size * seqlen, ) + row_shape

    # Order indexes of output (append rank indexes as well).
    output_ix = np.insert(col_ix, 0, batch_ix)
    output_ix = np.append(output_ix, (rank_ix[0], rank_ix[-1]))

    # Prepare contraction operands.
    ops = [input_shape, input_ix]
    for core_ix, core_shape in zip(cores_ix, cores_shape):
        ops.append(core_shape)
        ops.append(core_ix)
    ops.append(output_ix)
    ops = [tuple(op) for op in ops]
    return contract_expression(*ops)
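
# Illustrative sketch of the expression above for ndim == 2, assuming an input
# x of shape (B, n1, n2) and cores g1 of shape (r0, n1, m1, r1) and g2 of
# shape (r1, n2, m2, r2) (the names x, g1, g2 are hypothetical). The built
# expression is equivalent to the plain einsum
#
#     T.einsum('bnp,anxc,cpyd->bxyad', x, g1, g2)
#
# where b is the flattened batch index, a and d are the boundary ranks r0 and
# r2 (both kept in the output and equal to one), and c is the inner rank r1.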


class TTCompressedLinear(T.nn.Module):
    """Class TTCompressedLinear is a layer which represents the weight matrix
    of a linear layer in factorized form as a tensor-train (TT) matrix.

    >>> linear_layer = T.nn.Linear(6, 6)
    >>> tt_layer = TTCompressedLinear \
    ...     .from_linear(linear_layer, rank=2, shape=((2, 3), (3, 2)))
    """
    def __init__(self, cores: Sequence[T.Tensor],
                 bias: Optional[T.Tensor] = None):
        super().__init__()
        for i, core in enumerate(cores):
            if core.ndim != 4:
                raise ValueError('Expected number of dimensions of the '
                                 f'{i}-th core is 4 but given {core.ndim}.')

        # Prepare contraction expression.
        self.rank = (1, ) + tuple(core.shape[3] for core in cores)
        self.shape = (tuple(core.shape[1] for core in cores),
                      tuple(core.shape[2] for core in cores))
        self.contraction = make_contraction(self.shape, self.rank)

        # The TT-matrix is applied on the left (output = input @ W), so the
        # row modes define the number of input features and the column modes
        # the number of output features.
        self.in_features = int(np.prod(self.shape[0]))
        self.out_features = int(np.prod(self.shape[1]))

        # Create trainable variables.
        self.cores = T.nn.ParameterList(
            T.nn.Parameter(core) for core in cores)
        self.bias = None
        if bias is not None:
            if bias.numel() != self.out_features:
                raise ValueError(f'Expected bias size is {self.out_features} '
                                 f'but its shape is {bias.shape}.')
            self.bias = T.nn.Parameter(bias)

    def forward(self, input: T.Tensor) -> T.Tensor:
        # Replace the feature dimension with the row multi-index in order to
        # contract the input with the TT-matrix.
        input_shape = input.shape
        input = input.reshape(-1, *self.shape[0])

        # Contract input with weights and replace the multi-index back with
        # the feature dimension.
        output = self.contraction(input, *self.cores)
        output = output.reshape(*input_shape[:-1], self.out_features)
        if self.bias is not None:
            output += self.bias
        return output

    @classmethod
    def from_linear(cls, linear: T.nn.Linear,
                    shape: Tuple[Tuple[int, ...], Tuple[int, ...]],
                    rank: int, **kwargs):
        ndim = len(shape[0])

        # Prepare information about shape and rank of TT (not TTM).
        tt_rank = (1, ) + (rank, ) * (ndim - 1) + (1, )
        tt_shape = tuple(n * m for n, m in zip(*shape))

        # Reshape the weight matrix to a tensor indexed like a TT-matrix,
        # i.e. interleave the row and column mode indexes.
        matrix = linear.weight.data.T
        tensor = matrix.reshape(shape[0] + shape[1])
        for i in range(ndim - 1):
            tensor = tensor.moveaxis(ndim + i, 2 * i + 1)

        # Reshape TT-matrix to a plain TT and apply decomposition.
        tensor = tensor.reshape(tt_shape)
        cores = ttd(tensor, tt_rank, **kwargs)

        # Reshape TT-cores back to TT-matrix cores (TTM-cores).
        core_shapes = zip(tt_rank, *shape, tt_rank[1:])
        cores = [core.reshape(core_shape)
                 for core, core_shape in zip(cores, core_shapes)]

        # Make a copy of the bias if it exists.
        bias = None
        if linear.bias is not None:
            bias = T.clone(linear.bias.data)
        return TTCompressedLinear(cores, bias)
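
    # Sketch of the reshaping performed by from_linear for the hypothetical
    # shape ((2, 3), (4, 5)): the transposed weight of shape (6, 20) is viewed
    # as a (2, 3, 4, 5) tensor with row modes first, moveaxis interleaves the
    # modes to (2, 4, 3, 5), and the reshape to tt_shape = (8, 15) merges each
    # row/column pair before the TT decomposition. The resulting TT cores are
    # then reshaped to TTM cores (1, 2, 4, r) and (r, 3, 5, 1).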

    @classmethod
    def from_random(cls, shape: Tuple[Tuple[int, ...], Tuple[int, ...]],
                    rank: int, bias: bool = True):
        tt_ndim = len(shape[0])
        tt_rank = (1, ) + (rank, ) * (tt_ndim - 1) + (1, )
        core_shapes = zip(tt_rank, *shape, tt_rank[1:])
        cores = [T.randn(core_shape) for core_shape in core_shapes]

        bias_term = None
        if bias:
            out_features = int(np.prod(shape[1]))
            bias_term = T.randn(out_features)
        return TTCompressedLinear(cores, bias_term)
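

# Minimal usage sketch (the mode shapes, rank and tensor sizes below are
# arbitrary illustrative choices, not defaults of this module): build a random
# TT-compressed layer for a 512-dimensional feature space factorized as
# 8 * 8 * 8 and run a forward pass. A layer can also be obtained from an
# existing T.nn.Linear via TTCompressedLinear.from_linear(layer, shape=...,
# rank=...).
if __name__ == '__main__':
    tt_layer = TTCompressedLinear.from_random(
        shape=((8, 8, 8), (8, 8, 8)), rank=8)
    x = T.randn(2, 16, 512)  # (batch, sequence length, features)
    y = tt_layer(x)
    assert y.shape == (2, 16, 512)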