"""This module uses parts of rut5compressed. It shares the same module structure as model used in neural network compression experiments with rut5compressed. """ from functools import partial from typing import Optional, Tuple import numpy as np import torch as T from transformers import BartForConditionalGeneration from .configuration_bart import TTCompressedBartConfig from .linalg import ttd # noqa: F401 We need this import for HF. from .modules import TTCompressedLinear from .util import compress_linear_tt, map_module class TTCompressedBartForConditionGeneration(BartForConditionalGeneration): """Class TTCompressedBartForConditionGeneration defines a BART-based model with compressed linear layers with TT. """ LAYERS = r'/(de|en)coder/layers/\d+/fc[12]' config_class = TTCompressedBartConfig def __init__(self, config: TTCompressedBartConfig, shape: Optional[Tuple[Tuple[int], Tuple[int]]] = None, rank: Optional[int] = None, compress: bool = False): super().__init__(config) self.rank = rank or config.rank self.shape = shape if self.shape is None: self.shape = (tuple(self.config.shape_in), tuple(self.config.shape_out)) compress_fn = partial(compress_linear_tt, rank=self.rank) if not compress: compress_fn = self.convert self.model = map_module(self.model, compress_fn, self.LAYERS) def convert(self, module: T.nn.Module, path: str) -> T.nn.Module: if isinstance(module, T.nn.Linear): # If in_features < out_features of original linear module then this # is extension mapping; otherwise, it is embedding mapping and we # need to swap input and output shape. in_shape, out_shape = self.shape if module.in_features > module.out_features: out_shape, in_shape = self.shape shape = (in_shape, out_shape) bias = module.bias is not None return TTCompressedLinear.from_random(shape, self.rank, bias) return module TTCompressedBartForConditionGeneration \ .register_for_auto_class('AutoModelForSeq2SeqLM')