"""This module uses parts of rut5compressed. It shares the same module
structure as model used in neural network compression experiments with
rut5compressed.
"""

from functools import partial
from typing import Optional, Tuple

import numpy as np
import torch as T
from transformers import BartForConditionalGeneration

from .configuration_bart import TTCompressedBartConfig
from .linalg import ttd  # noqa: F401 We need this import for HF.
from .modules import TTCompressedLinear
from .util import compress_linear_tt, map_module


class TTCompressedBartForConditionGeneration(BartForConditionalGeneration):
    """Class TTCompressedBartForConditionGeneration defines a BART-based model
    with compressed linear layers with TT.
    """

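    # Pattern of module paths to compress: the fc1/fc2 feed-forward
    # projections of every encoder and decoder layer.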
    LAYERS = r'/(de|en)coder/layers/\d+/fc[12]'

    config_class = TTCompressedBartConfig

    def __init__(self, config: TTCompressedBartConfig,
                 shape: Optional[Tuple[Tuple[int], Tuple[int]]] = None,
                 rank: Optional[int] = None,
                 compress: bool = False):
        super().__init__(config)

        self.rank = rank or config.rank
        self.shape = shape
        if self.shape is None:
            self.shape = (tuple(self.config.shape_in),
                          tuple(self.config.shape_out))

        # With compress=True, the pretrained weights of the matched layers are
        # factorised with the TT decomposition; otherwise the layers are
        # replaced by randomly initialised TT-compressed layers (see convert).
        compress_fn = partial(compress_linear_tt, rank=self.rank,
                              shape=self.shape)
        if not compress:
            compress_fn = self.convert
        self.model = map_module(self.model, compress_fn, self.LAYERS)

    def convert(self, module: T.nn.Module, path: str) -> T.nn.Module:
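        """Replace a matched linear module with a randomly initialised
        TT-compressed linear layer of the configured shape and rank; any other
        module is returned unchanged.
        """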
        if isinstance(module, T.nn.Linear):
            # If in_features < out_features of the original linear module,
            # then this is the extension mapping; otherwise it is the
            # embedding mapping and we need to swap the input and output
            # shapes.
            in_shape, out_shape = self.shape
            if module.in_features > module.out_features:
                out_shape, in_shape = self.shape

            shape = (in_shape, out_shape)
            bias = module.bias is not None
            return TTCompressedLinear.from_random(shape, self.rank, bias)
        return module


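# Register the class with its auto class so that it can be loaded through
# AutoModelForSeq2SeqLM when this repository is used as custom (remote) code.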
TTCompressedBartForConditionGeneration \
    .register_for_auto_class('AutoModelForSeq2SeqLM')
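

if __name__ == '__main__':
    # A minimal usage sketch. The checkpoint name, factorisation shapes, and
    # TT-rank below are illustrative assumptions, not values shipped with
    # this repository.
    config = TTCompressedBartConfig.from_pretrained('facebook/bart-base')
    model = TTCompressedBartForConditionGeneration(
        config,
        shape=((4, 12, 16), (8, 12, 32)),  # factorisations of 768 and 3072
        rank=64,
        compress=False,  # initialise TT cores at random
    )
    print(model)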