from typing import Sequence, Tuple

from transformers import BartConfig


class TTCompressedBartConfig(BartConfig):
    """Configuration for a TT-compressed BART model.

    The TT shape is split into an input shape and an output shape so that
    each one is serialized to its own field in the JSON config.

    Args:
        *args: Positional arguments forwarded unchanged to ``BartConfig``.
        shape_in: Input part of the TT shape. Any sequence of ints is
            accepted; it is stored as a tuple.
        shape_out: Output part of the TT shape. Any sequence of ints is
            accepted; it is stored as a tuple.
        rank: Rank used for the TT compression (defaults to 128).
        **kwargs: Keyword arguments forwarded unchanged to ``BartConfig``.
    """

    def __init__(
        self,
        *args,
        shape_in: Sequence[int] = (),
        shape_out: Sequence[int] = (),
        rank: int = 128,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # JSON has no tuple type, so a config reloaded from disk passes these
        # fields in as lists. Normalise them to tuples so the attribute type
        # is the same regardless of where the config came from.
        self.shape_in: Tuple[int, ...] = tuple(shape_in)
        self.shape_out: Tuple[int, ...] = tuple(shape_out)
        self.rank = rank


# Mark this config class for registration with the auto classes so it can be
# resolved when saved/loaded as custom code.
TTCompressedBartConfig.register_for_auto_class()