# bart-base-detox-ttd / modeling_bart.py
"""This module uses parts of rut5compressed. It shares the same module
structure as model used in neural network compression experiments with
rut5compressed.
"""
from functools import partial
from typing import Optional, Tuple
import numpy as np
import torch as T
from transformers import BartForConditionalGeneration
from .configuration_bart import TTCompressedBartConfig
from .linalg import ttd # noqa: F401 We need this import for HF.
from .modules import TTCompressedLinear
from .util import compress_linear_tt, map_module


class TTCompressedBartForConditionGeneration(BartForConditionalGeneration):
    """Class TTCompressedBartForConditionGeneration defines a BART-based model
    whose feed-forward linear layers (fc1 and fc2) are compressed with the
    tensor-train (TT) decomposition.
    """

    # Matches the two feed-forward projections in every encoder and decoder
    # layer of the underlying BART model.
    LAYERS = r'/(de|en)coder/layers/\d+/fc[12]'

    config_class = TTCompressedBartConfig

    def __init__(self, config: TTCompressedBartConfig,
                 shape: Optional[Tuple[Tuple[int], Tuple[int]]] = None,
                 rank: Optional[int] = None,
                 compress: bool = False):
        super().__init__(config)
        self.rank = rank or config.rank
        self.shape = shape
        if self.shape is None:
            self.shape = (tuple(self.config.shape_in),
                          tuple(self.config.shape_out))

        # Either compress the weights of the matched layers with TT
        # (compress=True) or replace them with randomly initialised
        # TT-compressed layers of the requested shape and rank.
        compress_fn = partial(compress_linear_tt, rank=self.rank)
        if not compress:
            compress_fn = self.convert
        self.model = map_module(self.model, compress_fn, self.LAYERS)

    def convert(self, module: T.nn.Module, path: str) -> T.nn.Module:
        if isinstance(module, T.nn.Linear):
            # If in_features < out_features of the original linear module,
            # then this is the extension mapping (fc1, e.g. 768 -> 3072 in
            # bart-base); otherwise it is the embedding mapping (fc2) and we
            # need to swap the input and output shapes.
            in_shape, out_shape = self.shape
            if module.in_features > module.out_features:
                out_shape, in_shape = self.shape
            shape = (in_shape, out_shape)
            bias = module.bias is not None
            return TTCompressedLinear.from_random(shape, self.rank, bias)
        return module


TTCompressedBartForConditionGeneration \
    .register_for_auto_class('AutoModelForSeq2SeqLM')
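
# Usage sketch: loading the compressed checkpoint through the HF Auto API.
# Assumptions: the hosting repository's config.json contains the auto_map
# entry produced by the registration above, the repository also ships
# tokenizer files, and '<namespace>/bart-base-detox-ttd' is a placeholder id.
#
#     from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
#
#     repo = '<namespace>/bart-base-detox-ttd'  # placeholder repo id
#     tokenizer = AutoTokenizer.from_pretrained(repo)
#     model = AutoModelForSeq2SeqLM.from_pretrained(repo, trust_remote_code=True)
#     batch = tokenizer('a sentence to detoxify', return_tensors='pt')
#     print(tokenizer.decode(model.generate(**batch)[0], skip_special_tokens=True))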