from typing import Tuple | |
from transformers import BertConfig | |
class TTCompressedBertConfig(BertConfig):
    """Configuration for a TT-compressed BERT model.

    Extends ``BertConfig`` with the parameters of the tensor-train (TT)
    compression. The TT shape is split into separate input and output shapes
    so that each one is serialized to its own field in the JSON config.

    Args:
        *args: Positional arguments forwarded unchanged to ``BertConfig``.
        shape_in: Input part of the TT shape. Any iterable of ints is
            accepted; it is stored as a tuple.
        shape_out: Output part of the TT shape. Any iterable of ints is
            accepted; it is stored as a tuple.
        rank: Rank used for the TT decomposition (presumably the maximal TT
            rank -- confirm against the model code). Defaults to 128.
        **kwargs: Keyword arguments forwarded unchanged to ``BertConfig``.
    """

    def __init__(self, *args, shape_in: Tuple[int, ...] = (),
                 shape_out: Tuple[int, ...] = (), rank: int = 128,
                 **kwargs):
        super().__init__(*args, **kwargs)
        # JSON has no tuple type: a config reloaded from disk hands these
        # values back as lists. Normalise to tuples so that a freshly built
        # config and a reloaded one hold the same type and compare equal.
        self.shape_in = tuple(shape_in)
        self.shape_out = tuple(shape_out)
        self.rank = rank
# Register the custom config with AutoConfig so that saving it records an
# ``auto_map`` entry; "AutoConfig" is the library default, spelled out here.
TTCompressedBertConfig.register_for_auto_class(auto_class="AutoConfig")