# Copied from the rut5compressed/nn/modules.py module of the original repository.

from typing import Optional, Sequence, Tuple

import numpy as np
import torch as T
from opt_einsum import contract_expression
from opt_einsum.contract import ContractExpression

from .linalg import ttd


def make_contraction(shape, rank, batch_size=32,
                     seqlen=512) -> ContractExpression:
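    """Builds a reusable contraction expression which multiplies a batch of
    inputs, reshaped to the row multi-index, by the cores of a TT-matrix.
    The `batch_size` and `seqlen` arguments only provide operand sizes for
    planning the contraction path.
    """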
    ndim = len(rank) - 1
    row_shape, col_shape = shape

    # Generate all contraction indexes.
    row_ix, col_ix = np.arange(2 * ndim).reshape(2, ndim)
    rank_ix = 2 * ndim + np.arange(ndim + 1)
    batch_ix = 4 * ndim  # An index not used by rows, columns, or ranks.

    # Order indexes of cores.
    cores_ix = np.column_stack([rank_ix[:-1], row_ix, col_ix, rank_ix[1:]])
    cores_shape = zip(rank[:-1], row_shape, col_shape, rank[1:])

    # Order indexes of input (contraction by columns: X G_1 G_2 ... G_d).
    input_ix = np.insert(row_ix, 0, batch_ix)
    input_shape = (batch_size * seqlen, ) + row_shape

    # Order indexes of output (append rank indexes as well).
    output_ix = np.insert(col_ix, 0, batch_ix)
    output_ix = np.append(output_ix, (rank_ix[0], rank_ix[-1]))

    # Prepare contraction operands.
    ops = [input_shape, input_ix]
    for core_ix, core_shape in zip(cores_ix, cores_shape):
        ops.append(core_shape)
        ops.append(core_ix)
    ops.append(output_ix)
    ops = [tuple(op) for op in ops]

    return contract_expression(*ops)


class TTCompressedLinear(T.nn.Module):
    """Class TTCompressedLinear is a layer which represents a weight matrix of
    linear layer in factorized view as tensor train matrix.

    >>> linear_layer = T.nn.Linear(6, 6)
    >>> tt_layer = TTCompressedLinear \
    ...     .from_linear(linear_layer, rank=2, shape=((2, 3), (3, 2)))
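
    The compressed layer can then be applied like the original linear layer:

    >>> output = tt_layer(T.randn(1, 6))
    >>> output.shape
    torch.Size([1, 6])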
    """

    def __init__(self, cores: Sequence[T.Tensor],
                 bias: Optional[T.Tensor] = None):
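        """Constructs the layer from TT-matrix cores.

        Each core must be a 4-dimensional tensor of shape
        (rank[k], row[k], col[k], rank[k + 1]); the boundary ranks are
        expected to be 1. The optional bias must have as many elements as
        the product of the column (output) dimensions.
        """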
        super().__init__()

        for i, core in enumerate(cores):
            if core.ndim != 4:
                raise ValueError(f'Expected the {i}-th core to have 4 '
                                 f'dimensions but it has {core.ndim}.')

        # Prepare contraction expression.
        self.rank = (1, ) + tuple(core.shape[3] for core in cores)
        self.shape = (tuple(core.shape[1] for core in cores),
                      tuple(core.shape[2] for core in cores))
        self.contraction = make_contraction(self.shape, self.rank)

        # TT-matrix is applied on the left. So, this defines number of input
        # and output features.
        self.in_features = int(np.prod(self.shape[0]))
        self.out_features = int(np.prod(self.shape[1]))

        # Create trainable variables.
        self.cores = T.nn.ParameterList(T.nn.Parameter(core) for core in cores)
        self.bias = None
        if bias is not None:
            if bias.numel() != self.out_features:
                raise ValueError(f'Expected bias to have {self.out_features} '
                                 f'elements but its shape is {bias.shape}.')
            self.bias = T.nn.Parameter(bias)

    def forward(self, input: T.Tensor) -> T.Tensor:
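        """Applies the TT-compressed linear map to the last dimension of
        `input`, returning a tensor with `out_features` in its place.
        """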
        # Replace the feature dimension with a multi-index in order to
        # contract the input with the TT-matrix.
        input_shape = input.shape
        input = input.reshape(-1, *self.shape[0])

        # Contract the input with the TT-cores and fold the multi-index back
        # into the feature dimension.
        output = self.contraction(input, *self.cores)
        output = output.reshape(*input_shape[:-1], self.out_features)

        if self.bias is not None:
            output += self.bias
        return output

    @classmethod
    def from_linear(cls, linear: T.nn.Linear,
                    shape: Tuple[Tuple[int, ...], Tuple[int, ...]],
                    rank: int, **kwargs):
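        """Builds a TT-compressed layer from an existing linear layer by
        decomposing its weight matrix into a TT-matrix of the given `shape`
        and `rank`. The bias is copied as is, and extra keyword arguments are
        forwarded to the `ttd` decomposition routine.
        """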
        ndim = len(shape[0])

        # Prepare information about shape and rank of TT (not TTM).
        tt_rank = (1, ) + (rank, ) * (ndim - 1) + (1, )
        tt_shape = tuple(n * m for n, m in zip(*shape))

        # Reshape weight matrix to tensor indexes like TT-matrix.
        matrix = linear.weight.data.T
        tensor = matrix.reshape(shape[0] + shape[1])
        for i in range(ndim - 1):
            tensor = tensor.moveaxis(ndim + i, 2 * i + 1)

        # Reshape TT-matrix to a plain TT and apply decomposition.
        tensor = tensor.reshape(tt_shape)
        cores = ttd(tensor, tt_rank, **kwargs)

        # Reshape TT-cores back to TT-matrix cores (TTM-cores).
        core_shapes = zip(tt_rank, *shape, tt_rank[1:])
        cores = [core.reshape(core_shape)
                 for core, core_shape in zip(cores, core_shapes)]

        # Make copy of bias if it exists.
        bias = None
        if linear.bias is not None:
            bias = T.clone(linear.bias.data)

        return TTCompressedLinear(cores, bias)

    @classmethod
    def from_random(cls, shape: Tuple[Tuple[int, ...], Tuple[int, ...]],
                    rank: int, bias: bool = True):
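        """Builds a TT-compressed layer of the given TT-matrix `shape` and
        `rank` with randomly initialized cores (and bias, if requested).
        """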
        tt_ndim = len(shape[0])
        tt_rank = (1, ) + (rank, ) * (tt_ndim - 1) + (1, )
        core_shapes = zip(tt_rank, *shape, tt_rank[1:])
        cores = [T.randn(core_shape) for core_shape in core_shapes]

        bias_term = None
        if bias:
            out_features = np.prod(shape[1])
            bias_term = T.randn(out_features)

        return TTCompressedLinear(cores, bias_term)