import math
from abc import ABC, abstractmethod
from logging import getLogger
from typing import Literal
import torch
from torch import nn
from torch.nn import functional as F
from .parametrized_layer import Parametrization
from .utils import use_init_empty_weights
logger = getLogger(__name__)
class CompressionCriterion(ABC):
"""
Abstract class for compression criterion of a (target) parameter of a parametrized module.
"""
@abstractmethod
def __call__(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: A tensor of any shape
Returns: A boolean mask of the same shape as `x` where `False` indicates that the entry can be removed.
"""
raise NotImplementedError
class ThresholdCriterion(CompressionCriterion):
"""
Compression criterion based on a threshold. All entries at or below `self.threshold` can be removed.
"""
def __init__(self, threshold: float = 0.0):
self.threshold = threshold
def __call__(self, x: torch.Tensor) -> torch.Tensor:
return x > self.threshold
class ProjectedLinearParametrization(Parametrization, ABC):
"""
Implementation of a linear layer parametrization, factorizing the weight matrix as
`weight = ortho.weight @ torch.diag(mask) @ base.weight`.
Here, `ortho` is a linear layer with orthogonal columns, `mask` represents a (binary) diagonal matrix
that can be pruned, and `base` is a linear layer (determined by the choice of `ortho`).
Any child class needs to implement `_ortho_init` which creates `ortho`. Based on this, `mask` and `base` are
initialized such that the original weight matrix is obtained at initialization.
`mask` corresponds to the only target parameter of this parametrization. Pruning it will result in
a low-rank matrix representation of the parametrized linear module.
"""
base_class = nn.Linear
def __init__(
self,
mask_func: Literal["ste", "relu", "none"] = "ste",
mask_scaling_factor: float | str = "norm",
compression_criterion: CompressionCriterion = ThresholdCriterion(),
):
"""
Args:
mask_func: A function applied to the mask parameter in each forward pass implementing
custom functionalities. Available options: ["ste", "relu", "none"].
"ste" means using a straight-through estimator, i.e., in the forward pass, `mask` is binarized, which
is ignored in the backward pass. Before `mask` passed through a ReLU activation.
"relu" means that `mask` is passed through a ReLU activation.
"none" means that `mask` is not modified.
mask_scaling_factor: Conceptually, `mask` is initialized with ones, but rescaling to a smaller value
can vastly improve the training speed. `mask_scaling_factor` specifies this rescaling factor.
The rescaling should be compensated by scaling `ortho` accordingly in `self._ortho_init`.
If `mask_scaling_factor='norm'`, the scaling factor is chosen such that `mask` has unit L2 norm
(note that this can lead to different behavior during model tuning than a fixed factor
when target parameters have different numbers of elements).
compression_criterion: `CompressionCriterion` to be used in `self.reset_target_params(mode="compress")`.
"""
super().__init__()
self.mask_func = {
"ste": mask_func_ste,
"relu": mask_func_relu,
"none": mask_func_none,
}[mask_func]
self._mask_scaling_factor = mask_scaling_factor
self.compression_criterion = compression_criterion
def _forward(self, x: torch.Tensor) -> torch.Tensor:
# This implementation avoids an explicit materialization of `weight`.
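# Instead of one (out x in) matmul against a reconstructed weight, the input passes through
# base (in -> r), an elementwise scaling by the masked diagonal, and ortho (r -> out).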
x = self.base(x)
x = self.mask_func(self.mask, self.mask_scaling_factor) * x
x = self.ortho(x)
return x
def _weight(self) -> torch.Tensor:
# Compute the original weight matrix; do not use this in the forward pass, for efficiency reasons
mask = self.mask_func(self.mask, self.mask_scaling_factor)
return self.ortho.weight @ torch.diag(mask) @ self.base.weight
def _bias(self) -> torch.Tensor | None:
return self.ortho.bias
def _initialize(self, base_module: base_class) -> None:
factory_kwargs = {"device": base_module.weight.device, "dtype": base_module.weight.dtype}
in_dim, out_dim = base_module.in_features, base_module.out_features
proj_dim = min(in_dim, out_dim) # infer mask (bottleneck) dimension
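# Example (illustrative sizes): for nn.Linear(in_features=1024, out_features=256), proj_dim = 256,
# so ortho becomes Linear(256, 256), base becomes Linear(1024, 256, bias=False), and mask has 256 entries.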
# Initialize ortho layer ....
self.add_module(
"ortho",
nn.Linear(in_features=proj_dim, out_features=out_dim, bias=base_module.bias is not None, **factory_kwargs),
)
self._ortho_init(base_module.weight)
if base_module.bias is not None:
# It is important that ortho carries the bias (and not base) because ortho is used to compute the final
# output of the forward pass
self.ortho.bias.data.copy_(base_module.bias.data)
# ... and compute the base layer based on the choice of ortho (this only works if ortho has orthogonal columns)
base = base_module.__class__(in_features=in_dim, out_features=proj_dim, bias=False, **factory_kwargs)
base.weight.data.copy_(self.ortho.weight.data.T @ base_module.weight.data)
self.add_module("base", base)
# Creating (tunable) mask parameter ...
self.register_parameter("mask", torch.nn.Parameter(torch.ones(proj_dim, **factory_kwargs)))
# ... and rescale mask properly in a separate step
# (because reset_target_params calls mask_scaling_factor, which in turn may require mask to already exist)
self.reset_target_params()
@abstractmethod
def _ortho_init(self, weight: torch.Tensor) -> None:
"""
Initialize ortho layer. Must be implemented by child class.
Args:
weight: Weight matrix of the original linear layer module.
"""
raise NotImplementedError
def get_target_params(self) -> dict[str, torch.nn.Parameter]:
return {"mask": self.mask}
@property
def mask_scaling_factor(self) -> float:
if self._mask_scaling_factor == "norm":
# Choose scaling factor such that mask has unit L2 norm.
# Note: mask already needs to exist at this point to infer its shape.
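# Example: a mask with 64 entries yields a scaling factor of 1 / sqrt(64) = 0.125.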
self._mask_scaling_factor = 1 / math.sqrt(self.mask.numel())
return self._mask_scaling_factor
elif isinstance(self._mask_scaling_factor, float):
return self._mask_scaling_factor
else:
raise ValueError(f"Invalid mask_scaling_factor: {self._mask_scaling_factor}")
@property
def in_features(self) -> int:
return self.base.in_features
@property
def out_features(self) -> int:
return self.ortho.out_features
def reset_target_params(self, mode: Literal["full", "nonzero", "compress"] = "full") -> None:
with torch.no_grad():
if mode == "full":
# Scale mask values properly by self.mask_scaling_factor
self.mask.data = torch.ones_like(self.mask.data) * self.mask_scaling_factor
elif mode == "nonzero":
# Scale mask values properly by self.mask_scaling_factor
self.mask.data[self.mask.data > 0] = 1.0 * self.mask_scaling_factor
self.mask.data[self.mask.data < 0] = 0.0
elif mode == "compress":
if self.compression_criterion is None:
logger.warning("Compression criterion is not set. No op...")
return
# Select entries of parameter mask that should be kept
dim_select = self.compression_criterion(self.mask)
# Create and register compressed layers and mask
new_base = new_linear_from_mask(self.base, dim_select, column_select=False)
new_ortho = new_linear_from_mask(self.ortho, dim_select, column_select=True)
new_mask = self.mask[dim_select].clone().detach()
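# Example (illustrative): if 12 of 64 mask entries survive the criterion, base shrinks from
# (64, in) to (12, in), ortho from (out, 64) to (out, 12), and mask from 64 to 12 entries.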
del self.mask, self.base, self.ortho
self.register_module("base", new_base)
self.register_module("ortho", new_ortho)
self.register_parameter("mask", nn.Parameter(new_mask))
else:
raise ValueError(f"Invalid mode: {mode}")
def get_num_params(self, compressed: bool = False, target_params: dict[str, torch.Tensor] | None = None) -> int:
if not compressed:
# Compute number of parameters for full linear layer
num_params = self.in_features * self.out_features
if self.bias is not None:
num_params += self.out_features
return num_params
else:
# Compute number of mask values that could be discarded by self.reset_target_params(mode="compress") ...
if target_params is not None:
sparsity = mask_sparsity(target_params["mask"] != 0.0, threshold=0.0)
else:
sparsity = mask_sparsity(self.mask)
# ... and compute the (hypothetical) number of parameters for a compressed module.
num_params = self.in_features * sparsity + sparsity * self.out_features
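# Example (illustrative numbers): for in_features = out_features = 1024 and 128 surviving mask
# entries, this gives 1024 * 128 + 128 * 1024 = 262,144 parameters (plus bias), versus
# 1024 * 1024 = 1,048,576 for the full layer.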
if self.bias is not None:
num_params += self.out_features
# If the number of parameters for the compressed module would be larger than the number of parameters
# for the full module, return the latter because we can always unparametrize to the original module if
# compression would not be effective.
num_params = min(self.get_num_params(compressed=False), num_params)
return num_params
class SVDLinearParametrization(ProjectedLinearParametrization):
"""
Implementation of a linear layer parametrization using SVD decomposition.
If the SVD of weight is U * S * V^T, then `ortho.weight = U` and `base.weight = S * V^T`.
As base is computed automatically by `_initialize`, `_ortho_init` only needs to compute U and
scale it properly with `mask_scaling_factor`. The singular values S are buffered just in case they are needed
in the tuning process.
"""
def _ortho_init(self, weight: torch.Tensor) -> None:
k = min(weight.shape[0], weight.shape[1])
if use_init_empty_weights.get():
# Check if the init_empty_weights context is active which avoids a (costly) SVD computation and just
# initializes U and S as empty tensors. They are loaded later from a pretrained model.
logger.debug("Parametrizing with empty weights.")
U = torch.empty(weight.shape[0], k)
S = torch.empty(k, 1)
else:
# Detaching is important to avoid memory leaks. torch.linalg.svd does not support half-precision inputs, hence the cast to float32.
U, S, _ = torch.linalg.svd(weight.detach().float(), full_matrices=False)
# Rescaling U based on mask_scaling_factor
# This step is somewhat manual because calling mask_scaling_factor requires the mask to already exist
if self._mask_scaling_factor == "norm":
U = math.pow(k, 1 / 4) * U
else:
U = math.sqrt(1 / self._mask_scaling_factor) * U
factory_kwargs = {"device": weight.device, "dtype": weight.dtype}
self.ortho.weight.data.copy_(U.detach().to(**factory_kwargs))
self.register_buffer("S", S.detach().flatten().to(**factory_kwargs))
def mask_func_ste(mask: torch.Tensor, mask_scaling_factor: float) -> torch.Tensor:
# See ProjectedLinearParametrization.__init__ for more details.
mask = F.relu(mask)
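# Forward value: (mask > 0) * mask_scaling_factor, i.e., a binarized, rescaled mask.
# Backward: the binarized term is detached, while `mask - mask.detach()` is zero in value but
# passes the gradient straight through to `mask`.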
return (mask > 0).to(mask.dtype).detach() * mask_scaling_factor + mask - mask.detach()
def mask_func_relu(mask: torch.Tensor, mask_scaling_factor: float) -> torch.Tensor:
# See ProjectedLinearParametrization.__init__ for more details.
return F.relu(mask)
def mask_func_none(mask: torch.Tensor, mask_scaling_factor: float) -> torch.Tensor:
# See ProjectedLinearParametrization.__init__ for more details.
return mask
def mask_sparsity(mask: torch.Tensor, threshold: float = 0.0) -> int:
"""Simple util function to compute the number of non-zero elements of a mask, where an element is considered
non-zero if its value is strictly greater than `threshold`."""
return torch.count_nonzero(mask > threshold).item()
def new_linear_from_mask(module: nn.Linear, dim_select: torch.Tensor, column_select: bool = True) -> nn.Linear:
"""
Creates a new linear layer from an existing one based on a mask indicating which columns/rows to keep.
Args:
module: Module to be pruned.
dim_select: Boolean tensor mask indicating which columns/rows to keep.
column_select: Whether to prune columns (True) or rows (False) according to `dim_select`.
Returns: Pruned module.
"""
assert dim_select.dtype == torch.bool, "dim_select must be boolean"
in_features, out_features = module.in_features, module.out_features
sparsity = dim_select.sum().item()
if column_select:
in_features = sparsity
else:
out_features = sparsity
new_module = module.__class__(
in_features=in_features,
out_features=out_features,
bias=module.bias is not None,
device=module.weight.device,
dtype=module.weight.dtype,
)
weight = module.weight.data
if column_select:
weight = weight[:, dim_select]
else:
weight = weight[dim_select, :]
new_module.weight.data.copy_(weight.detach())
if new_module.bias is not None:
if column_select:
new_module.bias.data.copy_(module.bias.detach())
else:
# If rows are pruned, the bias needs to be pruned as well
new_module.bias.data.copy_(module.bias[dim_select].detach())
return new_module
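# Example (illustrative): keeping 3 of 8 input columns of an nn.Linear(8, 4) via
#   new_linear_from_mask(nn.Linear(8, 4), torch.tensor([1, 0, 1, 0, 0, 1, 0, 0], dtype=torch.bool))
# returns an nn.Linear(3, 4) whose weight consists of columns 0, 2, and 5 of the original weight.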