import math
from abc import ABC, abstractmethod
from logging import getLogger
from typing import Literal

import torch
from torch import nn
from torch.nn import functional as F

from .parametrized_layer import Parametrization
from .utils import use_init_empty_weights

logger = getLogger(__name__)


class CompressionCriterion(ABC):
    """
    Abstract class for compression criterion of a (target) parameter of a parametrized module.
    """

    @abstractmethod
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: A tensor of any shape

        Returns: A boolean mask of the same shape as `x` where `False` indicates that the entry can be removed.
        """
        raise NotImplementedError


class ThresholdCriterion(CompressionCriterion):
    """
    Compression criterion based on a threshold. All entries less than or equal to `self.threshold` can be removed.
    """

    def __init__(self, threshold: float = 0.0):
        self.threshold = threshold

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return x > self.threshold


class ProjectedLinearParametrization(Parametrization, ABC):
    """
    Implementation of a linear layer parametrization, factorizing the weight matrix as
    `weight = ortho.weight @ torch.diag(mask) @ base.weight`.
    Here, `ortho` is a linear layer with orthogonal columns, `mask` holds the diagonal of a (binary) diagonal matrix
    that can be pruned, and `base` is a linear layer (determined by the choice of `ortho`).
    Any child class needs to implement `_ortho_init` which creates `ortho`. Based on this, `mask` and `base` are
    initialized such that the original weight matrix is obtained at initialization.

    `mask` corresponds to the only target parameter of this parametrization. Pruning it will result in
    a low-rank matrix representation of the parametrized linear module.
    """

    base_class = nn.Linear

    def __init__(
        self,
        mask_func: Literal["ste", "relu", "none"] = "ste",
        mask_scaling_factor: float | str = "norm",
        compression_criterion: CompressionCriterion = ThresholdCriterion(),
    ):
        """
        Args:
            mask_func: A function applied to the mask parameter in each forward pass implementing
                custom functionalities. Available options: ["ste", "relu", "none"].
                "ste" means using a straight-through estimator, i.e., in the forward pass, `mask` is binarized, which
                is ignored in the backward pass. Before `mask` passed through a ReLU activation.
                "relu" means that `mask` is passed through a ReLU activation.
                "none" means that `mask` is not modified.
            mask_scaling_factor: Conceptually, `mask` is initialized with ones, but rescaling to a smaller value
                can vastly improve the training speed. `mask_scaling_factor` specifies this rescaling factor.
                The rescaling should be compensated by scaling `ortho` accordingly in `self._ortho_init`.
                If `mask_scaling_factor='norm'`, the scaling factor is chosen such that `mask` has unit L2 norm
                (note that this can lead to different behavior in model tuning than a fixed factor
                 when target parameters have different numbers of elements).
            compression_criterion: `CompressionCriterion` to be used in `self.reset_target_params(mode="compress")`.
        """
        super().__init__()
        self.mask_func = {
            "ste": mask_func_ste,
            "relu": mask_func_relu,
            "none": mask_func_none,
        }[mask_func]
        self._mask_scaling_factor = mask_scaling_factor
        self.compression_criterion = compression_criterion

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        # This implementation avoids explicitly materializing the full `weight` matrix.
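        # base(x) projects into the bottleneck, the (transformed) mask rescales each bottleneck dimension, and ortho
        # maps back to the output space; this is equivalent to F.linear(x, self._weight(), self._bias()).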
        x = self.base(x)
        x = self.mask_func(self.mask, self.mask_scaling_factor) * x
        x = self.ortho(x)
        return x

    def _weight(self) -> torch.Tensor:
        # Compute the original weight matrix; avoid using it in the forward pass for efficiency reasons.
        mask = self.mask_func(self.mask, self.mask_scaling_factor)
        return self.ortho.weight @ torch.diag(mask) @ self.base.weight

    def _bias(self) -> torch.Tensor | None:
        return self.ortho.bias

    def _initialize(self, base_module: base_class) -> None:
        factory_kwargs = {"device": base_module.weight.device, "dtype": base_module.weight.dtype}
        in_dim, out_dim = base_module.in_features, base_module.out_features
        proj_dim = min(in_dim, out_dim)  # infer mask (bottleneck) dimension

        # Initialize the ortho layer ...
        self.add_module(
            "ortho",
            nn.Linear(in_features=proj_dim, out_features=out_dim, bias=base_module.bias is not None, **factory_kwargs),
        )
        self._ortho_init(base_module.weight)
        if base_module.bias is not None:
            # It is important that ortho carries the bias (and not base) because ortho is used to compute the final
            # output of the forward pass
            self.ortho.bias.data.copy_(base_module.bias.data)

        # ... and compute the base layer based on the choice of ortho (this only works if ortho has orthogonal columns)
        base = base_module.__class__(in_features=in_dim, out_features=proj_dim, bias=False, **factory_kwargs)
        base.weight.data.copy_(self.ortho.weight.data.T @ base_module.weight.data)
        self.add_module("base", base)

        # Create the (tunable) mask parameter ...
        self.register_parameter("mask", torch.nn.Parameter(torch.ones(proj_dim, **factory_kwargs)))
        # ... and rescale mask properly in a separate step
        # (because reset_target_params calls mask_scaling_factor, which in turn may require mask to already exist)
        self.reset_target_params()

    @abstractmethod
    def _ortho_init(self, weight: torch.Tensor) -> None:
        """
        Initialize ortho layer. Must be implemented by child class.

        Args:
            weight: Weight matrix of the original linear layer module.
        """
        raise NotImplementedError

    def get_target_params(self) -> dict[str, torch.nn.Parameter]:
        return {"mask": self.mask}

    @property
    def mask_scaling_factor(self) -> float:
        if self._mask_scaling_factor == "norm":
            # Choose scaling factor such that mask has unit L2 norm.
            # Note: mask already needs to exist at this point to infer its shape.
            self._mask_scaling_factor = 1 / math.sqrt(self.mask.numel())
            return self._mask_scaling_factor
        elif isinstance(self._mask_scaling_factor, float):
            return self._mask_scaling_factor
        else:
            raise ValueError(f"Invalid mask_scaling_factor: {self._mask_scaling_factor}")

    @property
    def in_features(self) -> int:
        return self.base.in_features

    @property
    def out_features(self) -> int:
        return self.ortho.out_features

    def reset_target_params(self, mode: Literal["full", "nonzero", "compress"] = "full") -> None:
        with torch.no_grad():
            if mode == "full":
                # Scale mask values properly by self.mask_scaling_factor
                self.mask.data = torch.ones_like(self.mask.data) * self.mask_scaling_factor
            elif mode == "nonzero":
                # Scale mask values properly by self.mask_scaling_factor
                self.mask.data[self.mask.data > 0] = 1.0 * self.mask_scaling_factor
                self.mask.data[self.mask.data < 0] = 0.0
            elif mode == "compress":
                if self.compression_criterion is None:
                    logger.warning("Compression criterion is not set. No op...")
                    return
                # Select entries of parameter mask that should be kept
                dim_select = self.compression_criterion(self.mask)
                # Create and register compressed layers and mask
                new_base = new_linear_from_mask(self.base, dim_select, column_select=False)
                new_ortho = new_linear_from_mask(self.ortho, dim_select, column_select=True)
                new_mask = self.mask[dim_select].clone().detach()
                del self.mask, self.base, self.ortho
                self.register_module("base", new_base)
                self.register_module("ortho", new_ortho)
                self.register_parameter("mask", nn.Parameter(new_mask))
            else:
                raise ValueError(f"Invalid mode: {mode}")

    def get_num_params(self, compressed: bool = False, target_params: dict[str, torch.Tensor] | None = None) -> int:
        if not compressed:
            # Compute number of parameters for full linear layer
            num_params = self.in_features * self.out_features
            if self.bias is not None:
                num_params += self.out_features
            return num_params
        else:
            # Compute the number of mask values that would be kept after self.reset_target_params(mode="compress") ...
            if target_params is not None:
                sparsity = mask_sparsity(target_params["mask"] != 0.0, threshold=0.0)
            else:
                sparsity = mask_sparsity(self.mask)
            # ... and compute the (hypothetical) number of parameters for a compressed module.
            num_params = self.in_features * sparsity + sparsity * self.out_features
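            # For example, a 1024x1024 layer with 128 surviving mask entries would need 1024*128 + 128*1024 = 262,144
            # weights instead of 1024*1024 = 1,048,576 for the dense layer (plus the bias in both cases).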
            if self.bias is not None:
                num_params += self.out_features
            # If the number of parameters for the compressed module would be larger than the number of parameters
            # for the full module, return the latter because we can always unparametrize to the original module if
            # compression would not be effective.
            num_params = min(self.get_num_params(compressed=False), num_params)
            return num_params


class SVDLinearParametrization(ProjectedLinearParametrization):
    """
    Implementation of a linear layer parametrization using SVD decomposition.
    If the SVD of weight is U * S * V^T, then `ortho.weight = U` and `base.weight = S * V^T`.
    As base is computed automatically by `_initialize`, `_ortho_init` only needs to compute U and
    scale it properly with `mask_scaling_factor`. The singular values S are buffered just in case they are needed
    in the tuning process.
    """

    def _ortho_init(self, weight: torch.Tensor) -> None:
        k = min(weight.shape[0], weight.shape[1])
        if use_init_empty_weights.get():
            # The init_empty_weights context is active, so skip the (costly) SVD computation and just
            # initialize U and S as empty tensors; they are loaded later from a pretrained model.
            logger.debug("Parametrizing with empty weights.")
            U = torch.empty(weight.shape[0], k)
            S = torch.empty(k, 1)
        else:
            # Detaching is important to avoid memory leaks. torch.linalg.svd does not support half-precision dtypes,
            # so cast to float32.
            U, S, _ = torch.linalg.svd(weight.detach().float(), full_matrices=False)
            # Rescaling U based on mask_scaling_factor
            # This step is somewhat manual because calling mask_scaling_factor requires the mask to already exist
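            # With ortho scaled by a factor c and mask initialized to the scaling factor s, the reconstructed weight
            # is c^2 * s * (U @ diag(S) @ V^T), so choosing c = sqrt(1 / s) (i.e., k^(1/4) for s = 1/sqrt(k)) recovers
            # the original weight at initialization.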
            if self._mask_scaling_factor == "norm":
                U = math.pow(k, 1 / 4) * U
            else:
                U = math.sqrt(1 / self._mask_scaling_factor) * U
        factory_kwargs = {"device": weight.device, "dtype": weight.dtype}
        self.ortho.weight.data.copy_(U.detach().to(**factory_kwargs))
        self.register_buffer("S", S.detach().flatten().to(**factory_kwargs))


def mask_func_ste(mask: torch.Tensor, mask_scaling_factor: float) -> torch.Tensor:
    # See ProjectedLinearParametrization.__init__ for more details.
    mask = F.relu(mask)
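    # Forward value: a binary indicator (mask > 0) scaled by mask_scaling_factor.
    # Backward: both detached terms carry no gradient, so the gradient flows through the raw `mask` term,
    # i.e., the binarization is treated as the identity in the backward pass.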
    return (mask > 0).to(mask.dtype).detach() * mask_scaling_factor + mask - mask.detach()


def mask_func_relu(mask: torch.Tensor, mask_scaling_factor: float) -> torch.Tensor:
    # See ProjectedLinearParametrization.__init__ for more details.
    return F.relu(mask)


def mask_func_none(mask: torch.Tensor, mask_scaling_factor: float) -> torch.Tensor:
    # See ProjectedLinearParametrization.__init__ for more details.
    return mask


def mask_sparsity(mask: torch.Tensor, threshold: float = 0.0) -> int:
    """Simple util function to compute the number of non-zero elements of a mask, where an element is considered
    non-zero if its value is strictly greater than `threshold`."""
    return torch.count_nonzero(mask > threshold).item()


def new_linear_from_mask(module: nn.Linear, dim_select: torch.Tensor, column_select: bool = True) -> nn.Linear:
    """
    Creates a new linear layer from an existing one based on a mask indicating which columns/rows to keep.

    Args:
        module: Module to be pruned.
        dim_select: Boolean tensor mask indicating which columns/rows to keep.
        column_select: Whether to prune columns (True) or rows (False) according to `dim_select`.

    Returns: Pruned module.
    """
    assert dim_select.dtype == torch.bool, "dim_select must be boolean"

    in_features, out_features = module.in_features, module.out_features
    sparsity = dim_select.sum().item()
    if column_select:
        in_features = sparsity
    else:
        out_features = sparsity
    new_module = module.__class__(
        in_features=in_features,
        out_features=out_features,
        bias=module.bias is not None,
        device=module.weight.device,
        dtype=module.weight.dtype,
    )
    weight = module.weight.data
    if column_select:
        weight = weight[:, dim_select]
    else:
        weight = weight[dim_select, :]
    new_module.weight.data.copy_(weight.detach())

    if new_module.bias is not None:
        if column_select:
            new_module.bias.data.copy_(module.bias.detach())
        else:
            # If rows are pruned, the bias needs to be pruned as well
            new_module.bias.data.copy_(module.bias[dim_select].detach())

    return new_module
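

if __name__ == "__main__":
    # Minimal sanity-check sketch (illustration only, not part of the library API). Assumptions: the file is run as a
    # module inside its package (the relative imports above require this), `use_init_empty_weights` is inactive by
    # default, and calling `_initialize`/`_forward`/`_weight` directly is acceptable here even though the surrounding
    # Parametrization framework normally drives them.
    torch.manual_seed(0)
    linear = nn.Linear(in_features=16, out_features=8, bias=True)

    parametrization = SVDLinearParametrization(mask_func="ste", mask_scaling_factor="norm")
    parametrization._initialize(linear)

    # At initialization, the factorization reproduces the original weight and forward pass (up to float error).
    x = torch.randn(4, 16)
    print(torch.allclose(parametrization._weight(), linear.weight, atol=1e-4))  # True
    print(torch.allclose(parametrization._forward(x), linear(x), atol=1e-4))  # True

    # Zeroing part of the mask and compressing yields smaller base/ortho layers (a rank-4 factorization).
    with torch.no_grad():
        parametrization.mask[4:].zero_()
    parametrization.reset_target_params(mode="compress")
    print(parametrization.base.out_features, parametrization.ortho.in_features)  # 4 4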