File size: 626 Bytes
4cacee8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
from typing import Tuple
from transformers import BartConfig
class TTCompressedBartConfig(BartConfig):
    """Configuration for a TT-compressed (tensor-train) BART model.

    Extends ``BartConfig`` with the TT compression hyper-parameters. The TT
    shape is split into an input part and an output part so that each one is
    serialized to its own field in the JSON config.

    Args:
        *args: Positional arguments forwarded unchanged to ``BartConfig``.
        shape_in: Input part of the TT shape. Any iterable of ints is
            accepted and stored as a tuple; ``None`` is stored as ``None``.
        shape_out: Output part of the TT shape, handled like ``shape_in``.
        rank: TT rank. Defaults to 128.
        **kwargs: Keyword arguments forwarded unchanged to ``BartConfig``.
    """

    def __init__(self, *args, shape_in: Tuple[int, ...] = (),
                 shape_out: Tuple[int, ...] = (), rank: int = 128, **kwargs):
        super().__init__(*args, **kwargs)
        # JSON has no tuple type, so a config reloaded from its JSON file
        # passes the shapes back in as lists. Normalise to tuples so the
        # attribute type is identical before and after a save/load round
        # trip. ``None`` is passed through to stay compatible with callers
        # that relied on it being accepted.
        self.shape_in = None if shape_in is None else tuple(shape_in)
        self.shape_out = None if shape_out is None else tuple(shape_out)
        self.rank = rank
# Register this config with ``AutoConfig`` (the default ``auto_class`` of
# ``register_for_auto_class``). Checkpoints saved with it then record the
# custom class, so they can be reloaded through the auto classes (which, for
# custom code, requires ``trust_remote_code=True``).
TTCompressedBartConfig.register_for_auto_class()
|