not-found committed on
Commit
bf24ef5
1 Parent(s): 3061fd9

Add TT-compressed model with rank 128

Files changed (8)
  1. README.md +33 -0
  2. config.json +89 -0
  3. configuration_bart.py +20 -0
  4. linalg.py +45 -0
  5. modeling_bart.py +61 -0
  6. modules.py +143 -0
  7. pytorch_model.bin +3 -0
  8. util.py +193 -0
README.md ADDED
@@ -0,0 +1,33 @@
+ ---
+ language:
+ - en
+ tags:
+ - detoxification
+ licenses:
+ - cc-by-nc-sa
+ pipeline_tag: text2text-generation
+ ---
+
+ **Model Overview**
+
+ This is a TT-compressed version of the original BART-based detoxification
+ model [s-nlp/bart-base-detox][1].
+
+ **How to use**
+
+ ```python
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+ model = AutoModelForSeq2SeqLM \
+     .from_pretrained('s-nlp/bart-base-detox-ttd', trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base')
+
+ toxics = ['that sick fuck is going to be out in 54 years.']
+ tokens = tokenizer(toxics, return_tensors='pt')
+ tokens = model.generate(**tokens, num_return_sequences=1, do_sample=False,
+                         temperature=1.0, repetition_penalty=10.0,
+                         max_length=128, num_beams=5)
+ neutrals = tokenizer.decode(tokens[0, ...], skip_special_tokens=True)
+ print(neutrals)  # stdout: She is going to be out in 54 years.
+ ```
+
+ [1]: https://huggingface.co/s-nlp/bart-base-detox
config.json ADDED
@@ -0,0 +1,89 @@
+ {
+   "_name_or_path": "facebook/bart-base",
+   "activation_dropout": 0.1,
+   "activation_function": "gelu",
+   "add_bias_logits": false,
+   "add_final_layer_norm": false,
+   "architectures": [
+     "TTCompressedBartForConditionGeneration"
+   ],
+   "attention_dropout": 0.1,
+   "auto_map": {
+     "AutoConfig": "configuration_bart.TTCompressedBartConfig",
+     "AutoModelForSeq2SeqLM": "modeling_bart.TTCompressedBartForConditionGeneration"
+   },
+   "bos_token_id": 0,
+   "classif_dropout": 0.1,
+   "classifier_dropout": 0.0,
+   "d_model": 768,
+   "decoder_attention_heads": 12,
+   "decoder_ffn_dim": 3072,
+   "decoder_layerdrop": 0.0,
+   "decoder_layers": 6,
+   "decoder_start_token_id": 2,
+   "dropout": 0.1,
+   "early_stopping": true,
+   "encoder_attention_heads": 12,
+   "encoder_ffn_dim": 3072,
+   "encoder_layerdrop": 0.0,
+   "encoder_layers": 6,
+   "eos_token_id": 2,
+   "forced_eos_token_id": 2,
+   "gradient_checkpointing": false,
+   "id2label": {
+     "0": "LABEL_0",
+     "1": "LABEL_1",
+     "2": "LABEL_2"
+   },
+   "init_std": 0.02,
+   "is_encoder_decoder": true,
+   "label2id": {
+     "LABEL_0": 0,
+     "LABEL_1": 1,
+     "LABEL_2": 2
+   },
+   "max_position_embeddings": 1024,
+   "model_type": "bart",
+   "no_repeat_ngram_size": 3,
+   "normalize_before": false,
+   "normalize_embedding": true,
+   "num_beams": 4,
+   "num_hidden_layers": 6,
+   "pad_token_id": 1,
+   "rank": 128,
+   "scale_embedding": false,
+   "shape_in": [
+     8,
+     8,
+     12
+   ],
+   "shape_out": [
+     16,
+     16,
+     12
+   ],
+   "task_specific_params": {
+     "summarization": {
+       "length_penalty": 1.0,
+       "max_length": 128,
+       "min_length": 12,
+       "num_beams": 4
+     },
+     "summarization_cnn": {
+       "length_penalty": 2.0,
+       "max_length": 142,
+       "min_length": 56,
+       "num_beams": 4
+     },
+     "summarization_xsum": {
+       "length_penalty": 1.0,
+       "max_length": 62,
+       "min_length": 11,
+       "num_beams": 6
+     }
+   },
+   "torch_dtype": "float32",
+   "transformers_version": "4.25.1",
+   "use_cache": true,
+   "vocab_size": 50266
+ }
configuration_bart.py ADDED
@@ -0,0 +1,20 @@
+ from typing import Tuple
+
+ from transformers import BartConfig
+
+
+ class TTCompressedBartConfig(BartConfig):
+     """Class TTCompressedBartConfig defines a configuration for TT-compressed
+     BART. Here, we split the shape into input and output shapes in order to
+     serialize them to separate fields in JSON.
+     """
+
+     def __init__(self, *args, shape_in: Tuple[int] = (),
+                  shape_out: Tuple[int] = (), rank: int = 128, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.shape_in = shape_in
+         self.shape_out = shape_out
+         self.rank = rank
+
+
+ TTCompressedBartConfig.register_for_auto_class()
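For illustration, a minimal sketch of constructing this configuration directly (the plain import path is an assumption; on the Hub the class is resolved via `auto_map` with `trust_remote_code=True`). The values follow config.json above: `shape_in` factorizes `d_model` (8 * 8 * 12 = 768) and `shape_out` factorizes the feed-forward width (16 * 16 * 12 = 3072).

```python
from configuration_bart import TTCompressedBartConfig  # assumed local import

config = TTCompressedBartConfig(shape_in=(8, 8, 12),     # 8 * 8 * 12 = 768
                                shape_out=(16, 16, 12),  # 16 * 16 * 12 = 3072
                                rank=128)
print(config.shape_in, config.shape_out, config.rank)
```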
linalg.py ADDED
@@ -0,0 +1,45 @@
+ from functools import partial
+ from typing import Sequence
+
+ import torch as T
+
+
+ def svd_truncated(mat: T.Tensor, rank: int):
+     lvecs, svals, rvecs = T.linalg.svd(mat, full_matrices=False)
+     return lvecs[:, :rank], svals[:rank], rvecs[:rank, :].T
+
+
+ def ttd(ten: T.Tensor, rank: Sequence[int], noiters: int = 1000,
+         method: str = 'tsvd') -> Sequence[T.Tensor]:
+     """Function ttd implements the tensor-train decomposition (TTD).
+     """
+     if ten.ndim + 1 != len(rank):
+         raise ValueError(f'Expected number of ranks is {ten.ndim + 1} '
+                          f'but given {len(rank)}.')
+     if rank[0] != 1 or rank[-1] != 1:
+         raise ValueError('The first and the last TT-ranks must be 1.')
+
+     if method == 'svd':
+         factorize = svd_truncated
+     elif method == 'tsvd':
+         factorize = partial(T.svd_lowrank, niter=noiters)
+     else:
+         raise ValueError(f'Unknown method: {method}.')
+
+     cores = []
+     shape = ten.shape
+
+     # Iterate over core shapes and split a core off the tensor at each step.
+     for core_shape in zip(rank, shape, rank[1:]):
+         # Matricization of the tensor over the first two axes.
+         mat = ten.reshape(core_shape[0] * core_shape[1], -1)
+         # Truncated Singular Value Decomposition (SVD).
+         lvecs, svals, rvecs = factorize(mat, core_shape[2])
+         # Absorb singular values into the left factor and shape the core.
+         core = lvecs * svals[None, :]
+         core = core.reshape(core_shape)
+         cores.append(core)
+         # Use the right factors as the tensor for the next step.
+         ten = rvecs.T
+
+     return cores
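A small sanity check for `ttd`, assuming `linalg.py` is importable as a plain module: when the requested ranks do not truncate anything (here `(1, 8, 10, 1)` is full rank for an 8x9x10 tensor), contracting the cores back together reproduces the input up to floating-point error.

```python
import torch as T

from linalg import ttd  # assumed plain import of the module above

ten = T.randn(8, 9, 10)
cores = ttd(ten, rank=(1, 8, 10, 1), method='svd')

# Contract G_1 G_2 G_3 back into a single tensor.
rec = cores[0]
for core in cores[1:]:
    rec = T.tensordot(rec, core, dims=([rec.ndim - 1], [0]))
rec = rec.squeeze(0).squeeze(-1)  # drop the unit boundary ranks
print(T.dist(rec, ten))  # ~0 up to numerical precision
```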
modeling_bart.py ADDED
@@ -0,0 +1,61 @@
+ """This module uses parts of rut5compressed. It shares the same module
+ structure as the models used in neural network compression experiments
+ with rut5compressed.
+ """
+
+ from functools import partial
+ from typing import Optional, Tuple
+
+ import torch as T
+ from transformers import BartForConditionalGeneration
+
+ from .configuration_bart import TTCompressedBartConfig
+ from .linalg import ttd  # noqa: F401 We need this import for HF.
+ from .modules import TTCompressedLinear
+ from .util import compress_linear_tt, map_module
+
+
+ class TTCompressedBartForConditionGeneration(BartForConditionalGeneration):
+     """Class TTCompressedBartForConditionGeneration defines a BART-based
+     model with TT-compressed linear layers.
+     """
+
+     LAYERS = r'/(de|en)coder/layers/\d+/fc[12]'
+
+     config_class = TTCompressedBartConfig
+
+     def __init__(self, config: TTCompressedBartConfig,
+                  shape: Optional[Tuple[Tuple[int], Tuple[int]]] = None,
+                  rank: Optional[int] = None,
+                  compress: bool = False):
+         super().__init__(config)
+
+         self.rank = rank or config.rank
+         self.shape = shape
+         if self.shape is None:
+             self.shape = (tuple(self.config.shape_in),
+                           tuple(self.config.shape_out))
+
+         compress_fn = partial(compress_linear_tt, shape=self.shape,
+                               rank=self.rank)
+         if not compress:
+             compress_fn = self.convert
+         self.model = map_module(self.model, compress_fn, self.LAYERS)
+
+     def convert(self, module: T.nn.Module, path: str) -> T.nn.Module:
+         if isinstance(module, T.nn.Linear):
+             # If in_features < out_features of the original linear module,
+             # then this is an extension mapping; otherwise, it is an
+             # embedding mapping and we need to swap input and output shapes.
+             in_shape, out_shape = self.shape
+             if module.in_features > module.out_features:
+                 out_shape, in_shape = self.shape
+
+             shape = (in_shape, out_shape)
+             bias = module.bias is not None
+             return TTCompressedLinear.from_random(shape, self.rank, bias)
+         return module
+
+
+ TTCompressedBartForConditionGeneration \
+     .register_for_auto_class('AutoModelForSeq2SeqLM')
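Since `map_module` is applied to `self.model`, the `LAYERS` pattern selects exactly the two feed-forward projections of every encoder and decoder block. A short illustration (the example paths follow BART's module naming, which is an assumption here):

```python
import re

patt = re.compile(r'/(de|en)coder/layers/\d+/fc[12]')
print(bool(patt.match('/encoder/layers/3/fc1')))               # True
print(bool(patt.match('/decoder/layers/0/fc2')))               # True
print(bool(patt.match('/encoder/layers/3/self_attn/q_proj')))  # False
```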
modules.py ADDED
@@ -0,0 +1,143 @@
+ # Copied from rut5compressed/nn/modules.py of the original repository.
+
+ from typing import Optional, Sequence, Tuple
+
+ import numpy as np
+ import torch as T
+ from opt_einsum import contract_expression
+ from opt_einsum.contract import ContractExpression
+
+ from .linalg import ttd
+
+
+ def make_contraction(shape, rank, batch_size=32,
+                      seqlen=512) -> ContractExpression:
+     ndim = len(rank) - 1
+     row_shape, col_shape = shape
+
+     # Generate all contraction indexes.
+     row_ix, col_ix = np.arange(2 * ndim).reshape(2, ndim)
+     rank_ix = 2 * ndim + np.arange(ndim + 1)
+     batch_ix = 4 * ndim  # Zero-based index.
+
+     # Order indexes of cores.
+     cores_ix = np.column_stack([rank_ix[:-1], row_ix, col_ix, rank_ix[1:]])
+     cores_shape = zip(rank[:-1], row_shape, col_shape, rank[1:])
+
+     # Order indexes of input (contraction by columns: X G_1 G_2 ... G_d).
+     input_ix = np.insert(row_ix, 0, batch_ix)
+     input_shape = (batch_size * seqlen, ) + row_shape
+
+     # Order indexes of output (append rank indexes as well).
+     output_ix = np.insert(col_ix, 0, batch_ix)
+     output_ix = np.append(output_ix, (rank_ix[0], rank_ix[-1]))
+
+     # Prepare contraction operands.
+     ops = [input_shape, input_ix]
+     for core_ix, core_shape in zip(cores_ix, cores_shape):
+         ops.append(core_shape)
+         ops.append(core_ix)
+     ops.append(output_ix)
+     ops = [tuple(op) for op in ops]
+
+     return contract_expression(*ops)
+
+
+ class TTCompressedLinear(T.nn.Module):
+     """Class TTCompressedLinear is a layer which represents the weight matrix
+     of a linear layer in factorized form as a tensor-train matrix.
+
+     >>> linear_layer = T.nn.Linear(6, 6)
+     >>> tt_layer = TTCompressedLinear \
+     ...     .from_linear(linear_layer, rank=2, shape=((2, 3), (3, 2)))
+     """
+
+     def __init__(self, cores: Sequence[T.Tensor],
+                  bias: Optional[T.Tensor] = None):
+         super().__init__()
+
+         for i, core in enumerate(cores):
+             if core.ndim != 4:
+                 raise ValueError('Expected number of dimensions of the '
+                                  f'{i}-th core is 4 but given {core.ndim}.')
+
+         # Prepare contraction expression.
+         self.rank = (1, ) + tuple(core.shape[3] for core in cores)
+         self.shape = (tuple(core.shape[1] for core in cores),
+                       tuple(core.shape[2] for core in cores))
+         self.contraction = make_contraction(self.shape, self.rank)
+
+         # The TT-matrix is applied on the left, so this defines the number
+         # of input and output features.
+         self.in_features = np.prod(self.shape[0])
+         self.out_features = np.prod(self.shape[1])
+
+         # Create trainable variables.
+         self.cores = T.nn.ParameterList(T.nn.Parameter(core) for core in cores)
+         self.bias = None
+         if bias is not None:
+             if bias.numel() != self.out_features:
+                 raise ValueError(f'Expected bias size is {self.out_features} '
+                                  f'but its shape is {bias.shape}.')
+             self.bias = T.nn.Parameter(bias)
+
+     def forward(self, input: T.Tensor) -> T.Tensor:
+         # We need to replace the feature dimension with multiple dimensions
+         # in order to contract with the TT-matrix.
+         input_shape = input.shape
+         input = input.reshape(-1, *self.shape[0])
+
+         # Contract input with weights and replace the multiple dimensions
+         # back with the feature dimension.
+         output = self.contraction(input, *self.cores)
+         output = output.reshape(*input_shape[:-1], self.out_features)
+
+         if self.bias is not None:
+             output += self.bias
+         return output
+
+     @classmethod
+     def from_linear(cls, linear: T.nn.Linear,
+                     shape: Tuple[Tuple[int], Tuple[int]], rank: int, **kwargs):
+         ndim = len(shape[0])
+
+         # Prepare information about shape and rank of TT (not TTM).
+         tt_rank = (1, ) + (rank, ) * (ndim - 1) + (1, )
+         tt_shape = tuple(n * m for n, m in zip(*shape))
+
+         # Reshape the weight matrix to a tensor indexed like a TT-matrix.
+         matrix = linear.weight.data.T
+         tensor = matrix.reshape(shape[0] + shape[1])
+         for i in range(ndim - 1):
+             tensor = tensor.moveaxis(ndim + i, 2 * i + 1)
+
+         # Reshape the TT-matrix to a plain TT and apply the decomposition.
+         tensor = tensor.reshape(tt_shape)
+         cores = ttd(tensor, tt_rank, **kwargs)
+
+         # Reshape TT-cores back to TT-matrix cores (TTM-cores).
+         core_shapes = zip(tt_rank, *shape, tt_rank[1:])
+         cores = [core.reshape(core_shape)
+                  for core, core_shape in zip(cores, core_shapes)]
+
+         # Make a copy of the bias if it exists.
+         bias = None
+         if linear.bias is not None:
+             bias = T.clone(linear.bias.data)
+
+         return TTCompressedLinear(cores, bias)
+
+     @classmethod
+     def from_random(cls, shape: Tuple[Tuple[int], Tuple[int]], rank: int,
+                     bias: bool = True):
+         tt_ndim = len(shape[0])
+         tt_rank = (1, ) + (rank, ) * (tt_ndim - 1) + (1, )
+         core_shapes = zip(tt_rank, *shape, tt_rank[1:])
+         cores = [T.randn(core_shape) for core_shape in core_shapes]
+
+         bias_term = None
+         if bias:
+             out_features = np.prod(shape[1])
+             bias_term = T.randn(out_features)
+
+         return TTCompressedLinear(cores, bias_term)
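A usage sketch for `TTCompressedLinear.from_linear` with the shapes this commit uses for BART's feed-forward layers; `method='svd'` is forwarded through `**kwargs` to `ttd`. The layer keeps the calling convention of `torch.nn.Linear`, so leading batch and sequence dimensions pass through unchanged.

```python
import torch as T

linear = T.nn.Linear(768, 3072)
tt = TTCompressedLinear.from_linear(linear, shape=((8, 8, 12), (16, 16, 12)),
                                    rank=128, method='svd')
print(tt.in_features, tt.out_features)  # 768 3072

x = T.randn(2, 16, 768)  # (batch, sequence, features)
print(tt(x).shape)       # torch.Size([2, 16, 3072])
```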
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a97ccb2e65c441bf9d23d4e4f48e9e88efe407c54874034acce71c6706c4562e
+ size 536167389
util.py ADDED
@@ -0,0 +1,193 @@
1
+ # Copied from rut5compressed/util.py of rut5compressed repository.
2
+
3
+ import logging
4
+ import re
5
+ from functools import wraps
6
+ from re import Pattern
7
+ from typing import Callable, Dict, Optional, Tuple
8
+
9
+ import numpy as np
10
+ import torch as T
11
+
12
+ from .modules import TTCompressedLinear
13
+
14
+
15
+ def map_module(root: T.nn.Module,
16
+ func: Callable[[T.nn.Module, str], T.nn.Module],
17
+ patt: Optional[str] = None) -> T.nn.Module:
18
+ """Function ``map_module`` applies a function to each leaf of module tree
19
+ which matches to a specified pattern.
20
+
21
+ Parameters
22
+ ----------
23
+ root : torch.nn.Module
24
+ Module to modify.
25
+ func : callable
26
+ Function to be applied to every module (or matched to pattern) in
27
+ module tree.
28
+ patt : str, optional
29
+ Pattern to filter modules by path in module tree.
30
+
31
+ Returns
32
+ -------
33
+ torch.nn.Module
34
+ Module modified in-place.
35
+ """
36
+ @wraps(func)
37
+ def func_safe(*args, **kwargs):
38
+ node = func(*args, **kwargs)
39
+ if not isinstance(node, T.nn.Module):
40
+ raise ValueError('Mapped result must be toch.nn.Module type '
41
+ f'but given {type(node)}.')
42
+ return node
43
+
44
+ return _map_module(root, func_safe, re.compile(patt or r'.*'), '')
45
+
46
+
47
+ def _map_module(root: T.nn.Module,
48
+ func: Callable[[T.nn.Module, str], T.nn.Module], patt: Pattern,
49
+ path: str) -> T.nn.Module:
50
+ for name, child in root.named_children():
51
+ node = _map_module(child, func, patt, f'{path}/{name}')
52
+ if node != child:
53
+ setattr(root, name, node)
54
+ if patt.match(path or '/'):
55
+ root = func(root, path or '/')
56
+ return root
57
+
58
+
59
+ def convert_linear(module: T.nn.Linear, ctor, **kwargs) -> T.nn.Module:
60
+ """Function convert_linear takes module and returns linear module with
61
+ approximate matmul. Non-linear modules are returned intact.
62
+ """
63
+ if not isinstance(module, T.nn.Linear):
64
+ return module
65
+ raise NotImplementedError
66
+
67
+
68
+ def numel(module: T.nn.Module):
69
+ value = sum(x.numel() for x in module.parameters()) + \
70
+ sum(x.numel() for x in module.buffers())
71
+
72
+ def account_prunned(module: T.nn.Module, path: str):
73
+ nonlocal value
74
+ for name, attr in vars(module).items():
75
+ if not name.endswith('_mask') or not isinstance(attr, T.Tensor):
76
+ continue
77
+
78
+ weight_name = name[:-5]
79
+ if not hasattr(module, weight_name):
80
+ continue
81
+
82
+ weight = getattr(module, weight_name)
83
+ value -= weight.numel() - attr.sum()
84
+ value += attr.numel()
85
+ return module
86
+
87
+ def account_quantized(module: T.nn.Module, path: str):
88
+ nonlocal value
89
+ if isinstance(module, T.nn.quantized.Linear):
90
+ value += module.weight().numel()
91
+ if module.bias() is not None:
92
+ value += module.bias().numel()
93
+ return module
94
+
95
+ def account_rest(module: T.nn.Module, path: str):
96
+ account_prunned(module, path)
97
+ account_quantized(module, path)
98
+ return module
99
+
100
+ map_module(module, account_rest)
101
+ return value
102
+
103
+
104
+ def sizeof(module: T.nn.Module):
105
+ value = sum(x.numel() * x.element_size() for x in module.parameters()) + \
106
+ sum(x.numel() * x.element_size() for x in module.buffers())
107
+
108
+ def account_prunned(module: T.nn.Module, path: str):
109
+ nonlocal value
110
+ for name, attr in vars(module).items():
111
+ if not name.endswith('_mask') or not isinstance(attr, T.Tensor):
112
+ continue
113
+
114
+ weight_name = name[:-5]
115
+ if not hasattr(module, weight_name):
116
+ continue
117
+
118
+ weight = getattr(module, weight_name)
119
+ value -= (weight.numel() - attr.sum()) * weight.element_size()
120
+ value += attr.numel() * attr.element_size()
121
+ return module
122
+
123
+ def account_quantized(module: T.nn.Module, path: str):
124
+ nonlocal value
125
+ if isinstance(module, T.nn.quantized.Linear):
126
+ value += module.weight().numel() * module.weight().element_size()
127
+ if (bias := module.bias()) is not None:
128
+ value += bias.numel() * bias.element_size()
129
+ return module
130
+
131
+ def account_rest(module: T.nn.Module, path: str):
132
+ account_prunned(module, path)
133
+ account_quantized(module, path)
134
+ return module
135
+
136
+ map_module(module, account_rest)
137
+ return value
138
+
139
+
140
+ def flatten_module(module: T.nn.Module, regexp=None) -> Dict[str, T.nn.Module]:
141
+ modules = {}
142
+ map_module(module, lambda x, y: modules.update(**{y: x}) or x, regexp)
143
+ return modules
144
+
145
+
146
+ def print_flatten(module: T.nn.Module):
147
+ paths = []
148
+ path_len = 0
149
+ names = []
150
+ name_len = 0
151
+ indx_len = 0
152
+
153
+ def func(module, path):
154
+ nonlocal path_len, name_len, indx_len
155
+ paths.append(path)
156
+ path_len = max(path_len, len(path))
157
+ name = module.__class__.__name__
158
+ names.append(name)
159
+ name_len = max(name_len, len(name))
160
+ indx_len += 1
161
+ return module
162
+
163
+ map_module(module, func)
164
+
165
+ indx_len = int(np.ceil(np.log10(indx_len)))
166
+ fmt = f'{{indx:>{indx_len}s}} {{path:{path_len}s}} {{name:{name_len}s}}'
167
+ print(fmt.format(indx='#', path='Path', name='Layer'))
168
+ print('-' * (indx_len + path_len + name_len + 2))
169
+ for i, (path, name) in enumerate(zip(paths, names)):
170
+ print(fmt.format(indx=str(i), path=path, name=name))
171
+
172
+
173
+ def compress_linear_tt(module: T.nn.Module, path: str,
174
+ shape: Tuple[Tuple[int], Tuple[int]],
175
+ rank: int) -> T.nn.Module:
176
+ if not isinstance(module, T.nn.Linear):
177
+ return module
178
+
179
+ # TODO(@not-found): We need propper compression config.
180
+ inp_size = np.prod(shape[0])
181
+ out_size = np.prod(shape[1])
182
+ if inp_size == module.in_features and out_size == module.out_features:
183
+ pass
184
+ elif inp_size == module.out_features and out_size == module.in_features:
185
+ shape = (shape[1], shape[0])
186
+ else:
187
+ raise ValueError(
188
+ 'Input and output features does not match to compression shape: '
189
+ f'{shape[0]} vs {module.in_features} and {shape[1]} vs '
190
+ f'{module.out_features}.')
191
+
192
+ logging.info('apply tt compression to layer %s', path)
193
+ return TTCompressedLinear.from_linear(module, shape, rank) # noqa: F821