File size: 626 Bytes
4cacee8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from typing import Tuple

from transformers import BartConfig


class TTCompressedBartConfig(BartConfig):
    """Configuration for a tensor-train (TT) compressed BART model.

    The TT shape is stored as two separate attributes, ``shape_in`` and
    ``shape_out``, so that each one is serialized to its own field in the
    JSON config.

    Args:
        *args: Positional arguments forwarded unchanged to ``BartConfig``.
        shape_in: Input part of the TT shape (presumably the factorization of
            a compressed layer's input dimension; confirm against the model
            code). Defaults to an empty tuple.
        shape_out: Output part of the TT shape (presumably the factorization
            of a compressed layer's output dimension; confirm against the
            model code). Defaults to an empty tuple.
        rank: Rank used for the TT compression. Defaults to 128.
        **kwargs: Keyword arguments forwarded unchanged to ``BartConfig``.
    """

    def __init__(self, *args, shape_in: Tuple[int, ...] = (),
                 shape_out: Tuple[int, ...] = (), rank: int = 128, **kwargs):
        super().__init__(*args, **kwargs)
        # Stored as plain instance attributes so they are written to, and
        # read back from, the JSON config. NOTE: after a JSON round-trip
        # these values arrive as lists rather than tuples.
        self.shape_in = shape_in
        self.shape_out = shape_out
        self.rank = rank


# Register this custom config with AutoConfig (the default auto class,
# spelled out here for clarity) so it can be resolved through the Auto API
# when shared as custom code.
TTCompressedBartConfig.register_for_auto_class(auto_class="AutoConfig")