|
"""This module reuses parts of rut5compressed. It shares the same module
structure as the model used in the neural network compression experiments
with rut5compressed.
"""
|
|
|
from functools import partial |
|
from typing import Optional, Tuple |
|
|
|
import numpy as np |
|
import torch as T |
|
from transformers import BartForConditionalGeneration |
|
|
|
from .configuration_bart import TTCompressedBartConfig |
|
from .linalg import ttd |
|
from .modules import TTCompressedLinear |
|
from .util import compress_linear_tt, map_module |
|
|
|
|
|
class TTCompressedBartForConditionGeneration(BartForConditionalGeneration):
    """BART-based sequence-to-sequence model whose feed-forward linear layers
    (``fc1``/``fc2`` of every encoder and decoder layer) are Tensor-Train (TT)
    compressed.

    NOTE: the class name reads "ConditionGeneration" (not "Conditional...");
    it is kept as-is for backward compatibility with existing callers.
    """

    # Path pattern handed to ``map_module`` selecting the ``fc1``/``fc2``
    # projections in every encoder and decoder layer.
    LAYERS = r'/(de|en)coder/layers/\d+/fc[12]'

    config_class = TTCompressedBartConfig

    def __init__(self, config: TTCompressedBartConfig,
                 shape: Optional[Tuple[Tuple[int, ...],
                                       Tuple[int, ...]]] = None,
                 rank: Optional[int] = None,
                 compress: bool = False):
        """Build the BART model and swap the layers matched by ``LAYERS``.

        :param config: Model configuration. ``config.rank`` is the fallback
            for ``rank``; ``config.shape_in``/``config.shape_out`` are the
            fallback for ``shape``.
        :param shape: Pair ``(shape_in, shape_out)`` of TT shapes that
            :meth:`convert` passes to ``TTCompressedLinear.from_random``.
            Defaults to ``(tuple(config.shape_in), tuple(config.shape_out))``.
        :param rank: TT rank; falls back to ``config.rank`` when falsy.
        :param compress: If true, the existing dense layers are converted by
            ``compress_linear_tt`` with ``rank``. Otherwise they are replaced
            by randomly initialised ``TTCompressedLinear`` layers (see
            :meth:`convert`) -- presumably so that pretrained TT weights can
            be loaded into them afterwards; confirm with callers.
        """
        super().__init__(config)

        self.rank = rank or config.rank
        self.shape = shape
        if self.shape is None:
            self.shape = (tuple(self.config.shape_in),
                          tuple(self.config.shape_out))

        # Pick the per-layer callback for ``map_module``: decompose the dense
        # weights, or just swap in random TT layers of the configured shape.
        if compress:
            compress_fn = partial(compress_linear_tt, rank=self.rank)
        else:
            compress_fn = self.convert
        self.model = map_module(self.model, compress_fn, self.LAYERS)

    def convert(self, module: T.nn.Module, path: str) -> T.nn.Module:
        """Replace a dense linear layer with a randomly initialised TT layer.

        Used as the ``map_module`` callback when ``compress`` is false.

        :param module: Module whose path matched ``LAYERS``.
        :param path: Module path; part of the callback signature, unused here.
        :return: For an ``nn.Linear``, a ``TTCompressedLinear`` of TT rank
            ``self.rank`` with a bias iff the original layer had one. Any
            other module is returned unchanged.
        """
        if not isinstance(module, T.nn.Linear):
            return module

        in_shape, out_shape = self.shape
        # ``self.shape`` describes layers whose input is not wider than their
        # output; for a contracting layer the two TT shapes swap roles.
        if module.in_features > module.out_features:
            in_shape, out_shape = out_shape, in_shape

        has_bias = module.bias is not None
        return TTCompressedLinear.from_random((in_shape, out_shape),
                                              self.rank, has_bias)
|
|
|
|
|
# Register the model with the Hugging Face ``AutoModelForSeq2SeqLM`` auto
# class (custom-model registration).
TTCompressedBartForConditionGeneration.register_for_auto_class(
    'AutoModelForSeq2SeqLM')
|
|