ImageConductor / peft /utils /merge_utils.py
Yw22's picture
init demo
d711508
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import List, Literal
import torch
def reshape_weight_task_tensors(task_tensors, weights):
    """
    Append trailing size-1 dimensions to `weights` until it has as many dimensions as `task_tensors`, so that the two
    can be multiplied with broadcasting.

    Args:
        task_tensors (`torch.Tensor`): The tensor whose number of dimensions `weights` should match.
        weights (`torch.Tensor`): The tensor to be reshaped.

    Returns:
        `torch.Tensor`: A view of `weights` with the extra singleton dimensions appended.
    """
    missing_dims = task_tensors.dim() - weights.dim()
    # A non-positive count yields an empty tuple, leaving the shape untouched.
    trailing_ones = (1,) * missing_dims
    return weights.view(weights.shape + trailing_ones)
def magnitude_based_pruning(tensor: torch.Tensor, density: float) -> torch.Tensor:
    """
    Zero out every entry of `tensor` except the `int(density * tensor.numel())` entries with the largest absolute
    value.

    Args:
        tensor (`torch.Tensor`): The tensor to prune.
        density (`float`): The fraction of values to preserve. Should be in [0,1].

    Returns:
        `torch.Tensor`: A tensor of the same shape as `tensor` with the pruned entries set to zero.
    """
    num_kept = int(density * tensor.numel())
    flat_magnitudes = tensor.abs().reshape(-1)
    # Only the indices of the top-k magnitudes are needed; the values themselves are discarded.
    _, kept_indices = torch.topk(flat_magnitudes, k=num_kept, largest=True)

    keep_mask = torch.zeros_like(tensor).reshape(-1)
    keep_mask[kept_indices] = 1
    return tensor * keep_mask.reshape(tensor.shape)
def random_pruning(tensor: torch.Tensor, density: float, rescale: bool) -> torch.Tensor:
    """
    Prune random values based on the specified fraction `density`.

    Each entry is kept independently with probability `density` (Bernoulli mask) and set to zero otherwise.

    Args:
        tensor (`torch.Tensor`): The tensor to prune.
        density (`float`): The fraction of values to preserve. Should be in [0,1].
        rescale (`bool`):
            Whether to divide the surviving values by `density` so that the expected value of the result matches the
            original tensor.

    Returns:
        `torch.Tensor`: The pruned (and optionally rescaled) tensor. The input tensor is not modified.
    """
    keep_mask = torch.bernoulli(torch.full_like(input=tensor, fill_value=density))
    pruned_tensor = tensor * keep_mask
    # With density == 0 every entry is already zero; skipping the division avoids turning 0 / 0 into NaN.
    if rescale and density > 0:
        # `torch.div` is out-of-place, so its result must be assigned for the rescaling to take effect.
        pruned_tensor = torch.div(input=pruned_tensor, other=density)
    return pruned_tensor
def prune(
    tensor: torch.Tensor, density: float, method: Literal["magnitude", "random"], rescale: bool = False
) -> torch.Tensor:
    """
    Sparsify a task tensor with the pruning strategy selected by `method`.

    Args:
        tensor (`torch.Tensor`): The tensor to prune.
        density (`float`): The fraction of values to preserve. Should be in [0,1].
        method (`str`): The method to use to prune. Should be one of ["magnitude", "random"].
        rescale (`bool`):
            Whether to rescale the result to preserve the expected value of the original tensor. Only forwarded to
            the "random" method.

    Returns:
        `torch.Tensor`: The pruned tensor, or `tensor` itself when `density >= 1`.

    Raises:
        `ValueError`: If `density` is negative or `method` is not a known pruning method.
    """
    # Guard clauses: nothing to prune at full density, and negative densities are invalid.
    if density >= 1:
        warnings.warn(f"The density {density} is greater than or equal to 1, no pruning will be performed.")
        return tensor
    if density < 0:
        raise ValueError(f"Density should be >= 0, got {density}")

    if method == "magnitude":
        return magnitude_based_pruning(tensor, density)
    if method == "random":
        return random_pruning(tensor, density, rescale=rescale)
    raise ValueError(f"Unknown method {method}")
def calculate_majority_sign_mask(
    tensor: torch.Tensor, method: Literal["total", "frequency"] = "total"
) -> torch.Tensor:
    """
    Build a boolean mask marking the entries whose sign agrees with the majority sign across task tensors. Task
    tensors are stacked on dimension 0.

    Args:
        tensor (`torch.Tensor`): The stacked task tensors to get the mask from.
        method (`str`):
            How the majority sign is elected. "total" sums the raw values along dimension 0, "frequency" sums only
            their signs. Should be one of ["total", "frequency"].

    Returns:
        `torch.Tensor`: A boolean tensor of the same shape as `tensor`; `True` where the entry's sign equals the
        elected sign.

    Raises:
        `RuntimeError`: If `method` is not one of the supported values.
    """
    element_signs = tensor.sign()
    if method not in ("total", "frequency"):
        raise RuntimeError(f'Unimplemented mask method "{method}"')

    # "total" lets magnitudes vote; "frequency" gives every task tensor one vote per entry.
    votes = tensor if method == "total" else element_signs
    # A tie (sum of exactly zero) is resolved in favour of the positive sign.
    elected_sign = torch.where(votes.sum(dim=0) >= 0, 1, -1)
    return element_signs == elected_sign
def disjoint_merge(task_tensors: torch.Tensor, majority_sign_mask: torch.Tensor) -> torch.Tensor:
    """
    Average, along dimension 0, only those entries of `task_tensors` that are selected by `majority_sign_mask`.

    Args:
        task_tensors (`torch.Tensor`): The task tensors to merge.
        majority_sign_mask (`torch.Tensor`): The mask of the majority sign across the task tensors.

    Returns:
        `torch.Tensor`: The merged tensor.
    """
    selected_values = task_tensors * majority_sign_mask
    selected_count = majority_sign_mask.sum(dim=0)
    # Clamp the divisor to at least 1 so positions where nothing was selected yield 0 instead of dividing by zero.
    divisor = torch.clamp(selected_count, min=1.0)
    return selected_values.sum(dim=0) / divisor
def task_arithmetic(task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor:
    """
    Merge the task tensors using `task arithmetic`, i.e. as a weighted sum.

    Args:
        task_tensors(`List[torch.Tensor]`): The task tensors to merge.
        weights (`torch.Tensor`): The weights of the task tensors.

    Returns:
        `torch.Tensor`: The merged tensor.
    """
    stacked = torch.stack(task_tensors, dim=0)
    # Give the weights trailing singleton dims so each weight scales its whole task tensor.
    broadcast_weights = reshape_weight_task_tensors(stacked, weights)
    return (stacked * broadcast_weights).sum(dim=0)
def magnitude_prune(task_tensors: List[torch.Tensor], weights: torch.Tensor, density: float) -> torch.Tensor:
    """
    Merge the task tensors using `magnitude prune`: each task tensor is sparsified by magnitude-based pruning, then
    the pruned tensors are combined as a weighted sum.

    Args:
        task_tensors(`List[torch.Tensor]`): The task tensors to merge.
        weights (`torch.Tensor`): The weights of the task tensors.
        density (`float`): The fraction of values to preserve. Should be in [0,1].

    Returns:
        `torch.Tensor`: The merged tensor.
    """
    # Sparsify every task tensor independently before stacking.
    sparsified = [prune(task_tensor, density, method="magnitude") for task_tensor in task_tensors]
    stacked = torch.stack(sparsified, dim=0)
    # Weighted sum over the task dimension.
    broadcast_weights = reshape_weight_task_tensors(stacked, weights)
    return (stacked * broadcast_weights).sum(dim=0)
def ties(
    task_tensors: List[torch.Tensor],
    weights: torch.Tensor,
    density: float,
    majority_sign_method: Literal["total", "frequency"] = "total",
) -> torch.Tensor:
    """
    Merge the task tensors using `ties`: magnitude-based pruning, majority sign election, then a disjoint merge of
    the weighted tensors restricted to the entries that agree with the elected sign.

    Args:
        task_tensors(`List[torch.Tensor]`): The task tensors to merge.
        weights (`torch.Tensor`): The weights of the task tensors.
        density (`float`): The fraction of values to preserve. Should be in [0,1].
        majority_sign_method (`str`):
            The method to use to get the majority sign mask. Should be one of ["total", "frequency"].

    Returns:
        `torch.Tensor`: The merged tensor.
    """
    # Step 1: sparsify each task tensor by magnitude.
    sparsified = [prune(task_tensor, density, method="magnitude") for task_tensor in task_tensors]
    stacked = torch.stack(sparsified, dim=0)
    # Step 2: elect the sign on the unweighted tensors.
    sign_mask = calculate_majority_sign_mask(stacked, method=majority_sign_method)
    # Step 3: apply the per-task weights, then merge only the sign-consistent entries.
    broadcast_weights = reshape_weight_task_tensors(stacked, weights)
    return disjoint_merge(stacked * broadcast_weights, sign_mask)
def dare_linear(task_tensors: List[torch.Tensor], weights: torch.Tensor, density: float) -> torch.Tensor:
    """
    Merge the task tensors using `dare linear`: each task tensor is randomly pruned with rescaling requested, then
    the pruned tensors are combined as a weighted sum.

    Args:
        task_tensors(`List[torch.Tensor]`): The task tensors to merge.
        weights (`torch.Tensor`): The weights of the task tensors.
        density (`float`): The fraction of values to preserve. Should be in [0,1].

    Returns:
        `torch.Tensor`: The merged tensor.
    """
    # Randomly sparsify every task tensor independently before stacking.
    sparsified = [prune(task_tensor, density, method="random", rescale=True) for task_tensor in task_tensors]
    stacked = torch.stack(sparsified, dim=0)
    # Weighted sum over the task dimension.
    broadcast_weights = reshape_weight_task_tensors(stacked, weights)
    return (stacked * broadcast_weights).sum(dim=0)
def dare_ties(
    task_tensors: List[torch.Tensor],
    weights: torch.Tensor,
    density: float,
    majority_sign_method: Literal["total", "frequency"] = "total",
) -> torch.Tensor:
    """
    Merge the task tensors using `dare ties`: random pruning with rescaling requested, majority sign election, then a
    disjoint merge of the weighted tensors restricted to the entries that agree with the elected sign.

    Args:
        task_tensors(`List[torch.Tensor]`): The task tensors to merge.
        weights (`torch.Tensor`): The weights of the task tensors.
        density (`float`): The fraction of values to preserve. Should be in [0,1].
        majority_sign_method (`str`):
            The method to use to get the majority sign mask. Should be one of ["total", "frequency"].

    Returns:
        `torch.Tensor`: The merged tensor.
    """
    # Step 1: randomly sparsify each task tensor.
    sparsified = [prune(task_tensor, density, method="random", rescale=True) for task_tensor in task_tensors]
    stacked = torch.stack(sparsified, dim=0)
    # Step 2: elect the sign on the unweighted tensors.
    sign_mask = calculate_majority_sign_mask(stacked, method=majority_sign_method)
    # Step 3: apply the per-task weights, then merge only the sign-consistent entries.
    broadcast_weights = reshape_weight_task_tensors(stacked, weights)
    return disjoint_merge(stacked * broadcast_weights, sign_mask)